Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RFC] [commands] custom default arguments #1849

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions discord/ext/commands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@
from .converter import *
from .cooldowns import *
from .cog import *
from .default import CustomDefault
28 changes: 22 additions & 6 deletions discord/ext/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from .errors import *
from .cooldowns import Cooldown, BucketType, CooldownMapping, MaxConcurrency, DynamicCooldownMapping
from . import converter as converters
from . import default as defaults
from ._types import _BaseCommand
from .cog import Cog

Expand Down Expand Up @@ -632,12 +633,27 @@ async def do_conversion(self, ctx, converter, argument, param):
def _get_converter(self, param):
converter = param.annotation
if converter is param.empty:
if param.default is not param.empty:
converter = str if param.default is None else type(param.default)
else:
if param.default is param.empty or param.default is None:
converter = str
elif (inspect.isclass(param.default) and issubclass(param.default, defaults.CustomDefault)) or isinstance(param.default, defaults.CustomDefault):
converter = typing.Union[param.default.converters]
else:
converter = type(param.default)
return converter

async def _resolve_default(self, ctx, param):
try:
if inspect.isclass(param.default) and issubclass(param.default, defaults.CustomDefault):
instance = param.default()
return await instance.default(ctx=ctx, param=param)
elif isinstance(param.default, defaults.CustomDefault):
return await param.default.default(ctx=ctx, param=param)
except CommandError as e:
raise e
except Exception as e:
raise ConversionError(param.default, e) from e
return param.default

async def transform(self, ctx, param):
required = param.default is param.empty
converter = self._get_converter(param)
Expand Down Expand Up @@ -665,7 +681,7 @@ async def transform(self, ctx, param):
if self._is_typing_optional(param.annotation):
return None
raise MissingRequiredArgument(param)
return param.default
return await self._resolve_default(ctx, param)

previous = view.index
if consume_rest_is_special:
Expand Down Expand Up @@ -694,7 +710,7 @@ async def _transform_greedy_pos(self, ctx, param, required, converter):
result.append(value)

if not result and not required:
return param.default
return await self._resolve_default(ctx, param)
return result

async def _transform_greedy_var_pos(self, ctx, param, converter):
Expand All @@ -707,7 +723,7 @@ async def _transform_greedy_var_pos(self, ctx, param, converter):
view.index = previous
raise RuntimeError() from None # break loop
else:
return value
return value or await self._resolve_default(ctx, param)

@property
def clean_params(self) -> Dict[str, inspect.Parameter]:
Expand Down
106 changes: 106 additions & 0 deletions discord/ext/commands/default.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# -*- coding: utf-8 -*-

"""
The MIT License (MIT)

Copyright (c) 2015-2019 Rapptz

Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""

import discord

from .errors import MissingRequiredArgument

__all__ = (
'CustomDefault',
'Author',
'CurrentChannel',
'CurrentGuild',
'Call',
)

class CustomDefaultMeta(type):
def __new__(cls, *args, **kwargs):
name, bases, attrs = args
attrs['display'] = kwargs.pop('display', name)
return super().__new__(cls, name, bases, attrs, **kwargs)

def __repr__(cls):
return str(cls)

def __str__(cls):
return cls.display

class CustomDefault(metaclass=CustomDefaultMeta):
"""The base class of custom defaults that require the :class:`.Context`.

Classes that derive from this should override the :attr:`~.CustomDefault.converters` attribute to specify
converters to use and the :meth:`~.CustomDefault.default` method to do its conversion logic.
This method must be a coroutine.
"""
converters = (str,)

async def default(self, ctx, param):
"""|coro|

The method to override to do conversion logic.

If an error is found while converting, it is recommended to
raise a :exc:`.CommandError` derived exception as it will
properly propagate to the error handlers.

Parameters
-----------
ctx: :class:`.Context`
The invocation context that the argument is being used in.
"""
raise NotImplementedError('Derived classes need to implement this.')


class Author(CustomDefault):
"""Default parameter which returns the author for this context."""
converters = (discord.Member, discord.User)

async def default(self, ctx, param):
return ctx.author

class CurrentChannel(CustomDefault):
"""Default parameter which returns the channel for this context."""
converters = (discord.TextChannel,)

async def default(self, ctx, param):
return ctx.channel

class CurrentGuild(CustomDefault):
"""Default parameter which returns the guild for this context."""

async def default(self, ctx, param):
if ctx.guild:
return ctx.guild
raise MissingRequiredArgument(param)

class Call(CustomDefault):
"""Easy wrapper for lambdas/inline defaults."""

def __init__(self, callback):
self._callback = callback

async def default(self, ctx, param):
return self._callback(ctx, param)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like it would be good if this supported async functions as well, although it isn't a strong enough opinion.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I disagree. In order to use this with a new async function you'd have to define it outside of the command signature anyway. At that point just make a Default class. If you want this to support async functions, I suggest you file a PEP for async lambdas :^)

16 changes: 16 additions & 0 deletions docs/ext/commands/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,22 @@ Converters

.. autoclass:: discord.ext.commands.Greedy()

.. _ext_commands_api_custom_default:

Default Parameters
-------------------

.. autoclass:: discord.ext.commands.CustomDefault
:members:

.. autoclass:: discord.ext.commands.default.Author

.. autoclass:: discord.ext.commands.default.CurrentChannel

.. autoclass:: discord.ext.commands.default.CurrentGuild

.. autoclass:: discord.ext.commands.default.Call

.. _ext_commands_api_errors:

Exceptions
Expand Down
60 changes: 60 additions & 0 deletions docs/ext/commands/commands.rst
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,66 @@ handlers that allow us to do just that. First we decorate an error handler funct
The first parameter of the error handler is the :class:`.Context` while the second one is an exception that is derived from
:exc:`~ext.commands.CommandError`. A list of errors is found in the :ref:`ext_commands_api_errors` page of the documentation.


Custom Defaults
---------------

Custom defaults allow us to specify :class:`.Context`-based defaults. Custom defaults are always classes which inherit from
:class:`.CustomDefault`.

The library provides some simple default implementations in ``ext.commands.default`` - :class:`.default.Author`, :class:`.default.CurrentChannel`,
and :class:`.default.CurrentGuild` returning the corresponding properties from the Context. These can be used along with Converters to
simplify your individual commands. You can also use :class:`.default.Call` to quickly wrap existing functions.

A DefaultParam returning ``None`` is valid - if this should be an error, raise :class:`.MissingRequiredArgument`.

.. code-block:: python3
khazhyk marked this conversation as resolved.
Show resolved Hide resolved
:emphasize-lines: 14,17,32

class Image(Converter):
"""Find images associated with the message."""

async def convert(self, ctx, argument):
if argument.startswith("http://") or argument.startswith("https://"):
return argument

member = await MemberConverter().convert(ctx, argument)
if member:
return str(member.avatar_url_as(format="png"))

raise errors.BadArgument(f"{argument} isn't a member or url.")

class LastImage(CustomDefault):
"""Default param which finds the last image in chat."""
converters = (AnyImage,)

async def default(self, ctx, param):
for attachment in message.attachments:
if attachment.proxy_url:
return attachment.proxy_url
async for message in ctx.history(ctx, limit=100):
for embed in message.embeds:
if embed.thumbnail and embed.thumbnail.proxy_url:
return embed.thumbnail.proxy_url
for attachment in message.attachments:
if attachment.proxy_url:
return attachment.proxy_url

raise errors.MissingRequiredArgument(param)

@bot.command()
async def echo_image(ctx, *, image=LastImage):
async with aiohttp.ClientSession() as sess:
async with sess.get(image) as resp:
resp.raise_for_status()
my_bytes = io.BytesIO(await resp.content.read())
await ctx.send(file=discord.File(filename="your_image", fp=my_bytes))

.. tip:: You can change the name of a Custom Default that is displayed in help command by passing ``display`` meta option. Using previous example: ``class LastImage(CustomDefault, display='last image from chat')``




Checks
-------

Expand Down