Skip to content

Commit

Permalink
Merge pull request #104 from Filip3314/voice_channels
Browse files Browse the repository at this point in the history
Added support for creating and joining voice channels
  • Loading branch information
Sergeileduc authored Apr 29, 2023
2 parents e6d02ac + b38884c commit a78c572
Show file tree
Hide file tree
Showing 8 changed files with 194 additions and 7 deletions.
1 change: 1 addition & 0 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ invoke
sphinx-automodapi
build
flake8~=6.0.0
pynacl
32 changes: 31 additions & 1 deletion discord/ext/test/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import re
import typing
import datetime

import discord
import discord.http as dhttp
import pathlib
Expand Down Expand Up @@ -108,6 +109,9 @@ async def create_channel(
channel = make_text_channel(name, guild, permission_overwrites=perms, parent_id=parent_id)
elif channel_type == discord.ChannelType.category.value:
channel = make_category_channel(name, guild, permission_overwrites=perms)
elif channel_type == discord.ChannelType.voice.value:
channel = make_voice_channel(name, guild, permission_overwrites=perms)

else:
raise NotImplementedError(
"Operation occurred that isn't captured by the tests framework. This is dpytest's fault, please report"
Expand All @@ -124,6 +128,8 @@ async def delete_channel(self, channel_id: int, *, reason: str = None) -> None:
for sub_channel in channel.text_channels:
delete_channel(sub_channel)
delete_channel(channel)
if channel.type.value == discord.ChannelType.voice.value:
delete_channel(channel)

async def get_channel(self, channel_id: int) -> _types.JsonDict:
await callbacks.dispatch_event("get_channel", channel_id)
Expand Down Expand Up @@ -330,11 +336,13 @@ async def change_my_nickname(self, guild_id: int, nickname: str, *,
return {"nick": nickname}

async def edit_member(self, guild_id: int, user_id: int, *, reason: typing.Optional[str] = None,
**fields: typing.Any) -> None:
**fields: typing.Any) -> _types.JsonDict:
locs = _get_higher_locs(1)
member = locs.get("self", None)

await callbacks.dispatch_event("edit_member", fields, member, reason=reason)
member = update_member(member, nick=fields.get('nick'), roles=fields.get('roles'))
return facts.dict_from_member(member)

async def get_member(self, guild_id: int, member_id: int) -> _types.JsonDict:
locs = _get_higher_locs(1)
Expand Down Expand Up @@ -735,6 +743,28 @@ def make_category_channel(
return guild.get_channel(c_dict["id"])


def make_voice_channel(
name: str,
guild: discord.Guild,
position: int = -1,
id_num: int = -1,
permission_overwrites: typing.Optional[_types.JsonDict] = None,
parent_id: typing.Optional[int] = None,
bitrate: int = 192,
user_limit: int = 0

) -> discord.VoiceChannel:
if position == -1:
position = len(guild.voice_channels) + 1
c_dict = facts.make_voice_channel_dict(name, id_num, position=position, guild_id=guild.id,
permission_overwrites=permission_overwrites, parent_id=parent_id,
bitrate=bitrate, user_limit=user_limit)
state = get_state()
state.parse_channel_create(c_dict)

return guild.get_channel(c_dict["id"])


def delete_channel(channel: _types.AnyChannel) -> None:
c_dict = facts.make_text_channel_dict(channel.name, id_num=channel.id, guild_id=channel.guild.id)

Expand Down
14 changes: 14 additions & 0 deletions discord/ext/test/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import discord
from . import _types


generated_ids: int = 0


Expand Down Expand Up @@ -265,6 +266,10 @@ def make_dm_channel_dict(user: discord.User, id_num: int = -1, **kwargs: typing.
return make_channel_dict(discord.ChannelType.private, id_num, recipients=[dict_from_user(user)], **kwargs)


def make_voice_channel_dict(name: str, id_num: int = -1, **kwargs: typing.Any) -> _types.JsonDict:
return make_channel_dict(discord.ChannelType.voice.value, id_num, name=name, **kwargs)


def dict_from_overwrite(target: typing.Union[discord.Member, discord.Role],
overwrite: discord.PermissionOverwrite) -> _types.JsonDict:
allow, deny = overwrite.pair()
Expand Down Expand Up @@ -303,6 +308,15 @@ def dict_from_channel(channel: _types.AnyChannel) -> _types.JsonDict:
'permission_overwrites': [dict_from_overwrite(k, v) for k, v in channel.overwrites.items()],
'type': channel.type
}
if isinstance(channel, discord.VoiceChannel):
return {
'name': channel.name,
'position': channel.position,
'id': channel.id,
'guild_id': channel.guild.id,
'permission_overwrites': [dict_from_overwrite(k, v) for k, v in channel.overwrites.items()],
'type': channel.type
}


@typing.overload
Expand Down
35 changes: 29 additions & 6 deletions discord/ext/test/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
:mod:`discord.ext.test.verify`
"""


import sys
import asyncio
import logging
Expand Down Expand Up @@ -49,11 +48,13 @@ def require_config(func: typing.Callable[..., _types.T]) -> typing.Callable[...,
:param func: Function to decorate
:return: Function with added check for configuration being setup
"""

def wrapper(*args, **kwargs):
if _cur_config is None:
log.error("Attempted to make call before runner configured")
raise RuntimeError(f"Configure runner before calling {func.__name__}")
return func(*args, **kwargs)

wrapper.__wrapped__ = func
wrapper.__annotations__ = func.__annotations__
wrapper.__doc__ = func.__doc__
Expand Down Expand Up @@ -141,6 +142,22 @@ async def _message_callback(message: discord.Message) -> None:
await sent_queue.put(message)


async def _edit_member_callback(fields: typing.Any, member: discord.Member, reason: typing.Optional[str]):
"""
Internal callback. Updates a guild's voice states to reflect the given Member connecting to the given channel.
Other updates to members are handled in http.edit_member().
:param fields: Fields passed in from Member.edit().
:param member: The Member to edit.
:param reason: The reason for editing. Not used.
"""
data = {'user_id': member.id}
guild = member.guild
channel = fields.get('channel_id')
if not fields.get('nick') and not fields.get('roles'):
guild._update_voice_state(data, channel)


counter = count(0)


Expand Down Expand Up @@ -331,13 +348,15 @@ def get_config() -> RunnerConfig:
return _cur_config


def configure(client: discord.Client, num_guilds: int = 1, num_channels: int = 1, num_members: int = 1) -> None:
def configure(client: discord.Client, num_guilds: int = 1, num_text_channels: int = 1,
num_voice_channels: int = 1, num_members: int = 1) -> None:
"""
Set up the runner configuration. This should be done before any tests are run.
:param client: Client to configure with. Should be the bot/client that is going to be tested.
:param num_guilds: Number of guilds to start the configuration with. Default is 1
:param num_channels: Number of text channels in each guild to start with. Default is 1
:param num_text_channels: Number of text channels in each guild to start with. Default is 1
:param num_voice_channels: Number of voice channels in each guild to start with. Default is 1.
:param num_members: Number of members in each guild (other than the client) to start with. Default is 1.
"""

Expand Down Expand Up @@ -367,6 +386,7 @@ async def on_command_error(ctx, error):

# Configure global callbacks
callbacks.set_callback(_message_callback, "send_message")
callbacks.set_callback(_edit_member_callback, "edit_member")

back.get_state().stop_dispatch()

Expand All @@ -378,11 +398,14 @@ async def on_command_error(ctx, error):
channels = []
members = []
for guild in guilds:
for num in range(num_channels):
channel = back.make_text_channel(f"Channel_{num}", guild)
for num in range(num_text_channels):
channel = back.make_text_channel(f"TextChannel_{num}", guild)
channels.append(channel)
for num in range(num_voice_channels):
channel = back.make_voice_channel(f"VoiceChannel_{num}", guild)
channels.append(channel)
for num in range(num_members):
user = back.make_user(f"TestUser{str(num)}", f"{num+1:04}")
user = back.make_user(f"TestUser{str(num)}", f"{num + 1:04}")
member = back.make_member(user, guild, nick=user.name + f"_{str(num)}_nick")
members.append(member)
back.make_member(back.get_state().user, guild, nick=client.user.name + "_nick")
Expand Down
27 changes: 27 additions & 0 deletions discord/ext/test/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from . import factories as facts
from . import backend as back
from .voice import FakeVoiceChannel


class FakeState(dstate.ConnectionState):
Expand All @@ -37,6 +38,7 @@ def __init__(self, client: discord.Client, http: dhttp.HTTPClient, user: discord
self.shard_count = client.shard_count
self._get_websocket = lambda x: client.ws
self._do_dispatch = True
self._get_client = lambda: client

real_disp = self.dispatch

Expand Down Expand Up @@ -73,3 +75,28 @@ def _guild_needs_chunking(self, guild: discord.Guild):
Prevents chunking which can throw asyncio wait_for errors with tests under 60 seconds
"""
return False

def parse_channel_create(self, data) -> None:
"""
Need to make sure that FakeVoiceChannels are created when this is called to create VoiceChannels. Otherwise,
guilds would not be set up correctly.
:param data: info to use in channel creation.
"""
if data['type'] == discord.ChannelType.voice.value:
factory, ch_type = FakeVoiceChannel, discord.ChannelType.voice.value
else:
factory, ch_type = discord.channel._channel_factory(data['type'])

if factory is None:
return

guild_id = discord.utils._get_as_snowflake(data, 'guild_id')
guild = self._get_guild(guild_id)
if guild is not None:
# the factory can't be a DMChannel or GroupChannel here
channel = factory(guild=guild, state=self, data=data) # type: ignore
guild._add_channel(channel) # type: ignore
self.dispatch('guild_channel_create', channel)
else:
return
35 changes: 35 additions & 0 deletions discord/ext/test/voice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Callable

from discord import Client, VoiceClient
from discord.abc import Connectable, T
from discord.channel import VoiceChannel


class FakeVoiceClient(VoiceClient):
"""
Mock implementation of a Discord VoiceClient. VoiceClient.connect tries to contact the Discord API and is called
whenever connect() is called on a VoiceChannel, so we need to override that method and pass in the fake version
to prevent the program from actually making calls to the Discord API.
"""
async def connect(self, *, reconnect: bool, timeout: float, self_deaf: bool = False,
self_mute: bool = False) -> None:
self._connected.set()


class FakeVoiceChannel(VoiceChannel):
"""
Mock implementation of a Discord VoiceChannel. Exists just to pass a FakeVoiceClient into the superclass connect()
method.
"""

async def connect(
self,
*,
timeout: float = 60.0,
reconnect: bool = True,
cls: Callable[[Client, Connectable], T] = FakeVoiceClient,
self_deaf: bool = False,
self_mute: bool = False,
) -> T:
return await super().connect(timeout=timeout, reconnect=reconnect, cls=cls, self_deaf=self_deaf,
self_mute=self_mute)
28 changes: 28 additions & 0 deletions tests/test_create_channel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import pytest
import discord
import discord.ext.test as dpytest


@pytest.mark.asyncio
async def test_create_voice_channel(bot):
guild = bot.guilds[0]
http = bot.http

# create_channel checks the value of variables in the parent call context, so we need to set these for it to work
self = guild # noqa: F841
name = "voice_channel_1"
channel = await http.create_channel(guild, channel_type=discord.ChannelType.voice.value)
assert channel['type'] == discord.ChannelType.voice
assert channel['name'] == name


@pytest.mark.asyncio
async def test_make_voice_channel(bot):
guild = bot.guilds[0]
bitrate = 100
user_limit = 5
channel = dpytest.backend.make_voice_channel("voice", guild, bitrate=bitrate, user_limit=user_limit)
assert channel.name == "voice"
assert channel.guild == guild
assert channel.bitrate == bitrate
assert channel.user_limit == user_limit
29 changes: 29 additions & 0 deletions tests/test_voice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import pytest


@pytest.mark.asyncio
async def test_bot_join_voice(bot):
assert not bot.voice_clients
await bot.guilds[0].voice_channels[0].connect()
assert bot.voice_clients


@pytest.mark.asyncio
async def test_bot_leave_voice(bot):
voice_client = await bot.guilds[0].voice_channels[0].connect()
await voice_client.disconnect()
assert not bot.voice_clients


@pytest.mark.asyncio
async def test_move_member(bot):
guild = bot.guilds[0]
voice_channel = guild.voice_channels[0]
member = guild.members[0]

assert member.voice is None
await member.move_to(voice_channel)
assert member.voice.channel == voice_channel

await member.move_to(None)
assert member.voice is None

0 comments on commit a78c572

Please sign in to comment.