Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[V3 Config] Adjust functionality of get_attr #1342

Merged
merged 6 commits into from
Feb 26, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions redbot/cogs/customcom/customcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def get_now(self) -> str:
async def get(self,
message: discord.Message,
command: str) -> str:
ccinfo = await self.db(message.guild).commands.get_attr(command)
ccinfo = await self.db(message.guild).commands.get_raw(command, default=None)
if not ccinfo:
raise NotFound
else:
Expand All @@ -82,7 +82,7 @@ async def create(self,
response):
"""Create a customcommand"""
# Check if this command is already registered as a customcommand
if await self.db(ctx.guild).commands.get_attr(command):
if await self.db(ctx.guild).commands.get_raw(command, default=None):
raise AlreadyExists()
author = ctx.message.author
ccinfo = {
Expand All @@ -96,20 +96,20 @@ async def create(self,
'response': response

}
await self.db(ctx.guild).commands.set_attr(command,
ccinfo)
await self.db(ctx.guild).commands.set_raw(
command, value=ccinfo)

async def edit(self,
ctx: commands.Context,
command: str,
response: None):
"""Edit an already existing custom command"""
# Check if this command is registered
if not await self.db(ctx.guild).commands.get_attr(command):
if not await self.db(ctx.guild).commands.get_raw(command, default=None):
raise NotFound()

author = ctx.message.author
ccinfo = await self.db(ctx.guild).commands.get_attr(command)
ccinfo = await self.db(ctx.guild).commands.get_raw(command, default=None)

def check(m):
return m.channel == ctx.channel and m.author == ctx.message.author
Expand Down Expand Up @@ -138,18 +138,18 @@ def check(m):
author.id
)

await self.db(ctx.guild).commands.set_attr(command,
ccinfo)
await self.db(ctx.guild).commands.set_raw(
command, value=ccinfo)

async def delete(self,
ctx: commands.Context,
command: str):
"""Delete an already exisiting custom command"""
# Check if this command is registered
if not await self.db(ctx.guild).commands.get_attr(command):
if not await self.db(ctx.guild).commands.get_raw(command, default=None):
raise NotFound()
await self.db(ctx.guild).commands.set_attr(command,
None)
await self.db(ctx.guild).commands.set_raw(
command, value=None)


class CustomCommands:
Expand Down Expand Up @@ -326,7 +326,7 @@ async def on_message(self,
return

guild = message.guild
prefixes = await self.bot.db.guild(message.guild).get_attr('prefix')
prefixes = await self.bot.db.guild(guild).get_raw('prefix', default=[])

if len(prefixes) < 1:
def_prefixes = await self.bot.get_prefix(message)
Expand Down
10 changes: 5 additions & 5 deletions redbot/cogs/streams/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async def _initialize_lists(self):
@commands.command()
async def twitch(self, ctx, channel_name: str):
"""Checks if a Twitch channel is streaming"""
token = await self.db.tokens.get_attr(TwitchStream.__name__)
token = await self.db.tokens.get_raw(TwitchStream.__name__, default=None)
stream = TwitchStream(name=channel_name,
token=token)
await self.check_online(ctx, stream)
Expand Down Expand Up @@ -187,7 +187,7 @@ async def streamalert_list(self, ctx):
async def stream_alert(self, ctx, _class, channel_name):
stream = self.get_stream(_class, channel_name.lower())
if not stream:
token = await self.db.tokens.get_attr(_class.__name__)
token = await self.db.tokens.get_raw(_class.__name__, default=None)
stream = _class(name=channel_name,
token=token)
try:
Expand All @@ -210,7 +210,7 @@ async def stream_alert(self, ctx, _class, channel_name):
async def community_alert(self, ctx, _class, community_name):
community = self.get_community(_class, community_name)
if not community:
token = await self.db.tokens.get_attr(_class.__name__)
token = await self.db.tokens.get_raw(_class.__name__, default=None)
community = _class(name=community_name, token=token)
try:
await community.get_community_streams()
Expand Down Expand Up @@ -477,7 +477,7 @@ async def load_streams(self):
if not _class:
continue

token = await self.db.tokens.get_attr(_class.__name__)
token = await self.db.tokens.get_raw(_class.__name__)
streams.append(_class(token=token, **raw_stream))

# issue 1191 extended resolution: Remove this after suitable period
Expand All @@ -497,7 +497,7 @@ async def load_communities(self):
if not _class:
continue

token = await self.db.tokens.get_attr(_class.__name__)
token = await self.db.tokens.get_raw(_class.__name__, default=None)
communities.append(_class(token=token, **raw_community))

# issue 1191 extended resolution: Remove this after suitable period
Expand Down
59 changes: 16 additions & 43 deletions redbot/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,17 +267,13 @@ def is_value(self, item: str) -> bool:

return not isinstance(default, dict)

def get_attr(self, item: str, default=None, resolve=True):
def get_attr(self, item: str):
"""Manually get an attribute of this Group.

This is available to use as an alternative to using normal Python
attribute access. It is required if you find a need for dynamic
attribute access. It may be required if you find a need for dynamic
attribute access.

Note
----
Use of this method should be avoided wherever possible.

Example
-------
A possible use case::
Expand All @@ -287,32 +283,20 @@ async def some_command(self, ctx, item: str):
user = ctx.author

# Where the value of item is the name of the data field in Config
await ctx.send(await self.conf.user(user).get_attr(item))
await ctx.send(await self.conf.user(user).get_attr(item).foo())

Parameters
----------
item : str
The name of the data field in `Config`.
default
This is an optional override to the registered default for this
item.
resolve : bool
If this is :code:`True` this function will return a coroutine that
resolves to a "real" data value when awaited. If :code:`False`,
this method acts the same as `__getattr__`.

Returns
-------
`types.coroutine` or `Value` or `Group`
The attribute which was requested, its type depending on the value
of :code:`resolve`.
`Value` or `Group`
The attribute which was requested.

"""
value = getattr(self, item)
if resolve:
return value(default=default)
else:
return value
return self.__getattr__(item)

async def get_raw(self, *nested_path: str, default=...):
"""
Expand Down Expand Up @@ -350,6 +334,16 @@ async def get_raw(self, *nested_path: str, default=...):
"""
path = [str(p) for p in nested_path]

if default is ...:
poss_default = self.defaults
for ident in path:
try:
poss_default = poss_default[ident]
except KeyError:
break
else:
default = poss_default

try:
return deepcopy(await self.driver.get(*self.identifiers, *path))
except KeyError:
Expand Down Expand Up @@ -398,27 +392,6 @@ async def set(self, value):
)
await super().set(value)

async def set_attr(self, item: str, value):
"""Set an attribute by its name.

Similar to `get_attr` in the way it can be used to dynamically set
attributes by name.

Note
----
Use of this method should be avoided wherever possible.

Parameters
----------
item : str
The name of the attribute being set.
value
The raw data value to set the attribute as.

"""
value_obj = getattr(self, item)
await value_obj.set(value)

async def set_raw(self, *nested_path: str, value):
"""
Allows a developer to set data as if it was stored in a standard
Expand Down
29 changes: 15 additions & 14 deletions redbot/core/modlog.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ async def edit(self, data: dict):
case_emb = await self.message_content()
await self.message.edit(embed=case_emb)

await _conf.guild(self.guild).cases.set_attr(
str(self.case_number), self.to_json()
await _conf.guild(self.guild).cases.set_raw(
str(self.case_number), value=self.to_json()
)

async def message_content(self):
Expand Down Expand Up @@ -245,7 +245,7 @@ async def to_json(self):
"case_str": self.case_str,
"audit_type": self.audit_type
}
await _conf.casetypes.set_attr(self.name, data)
await _conf.casetypes.set_raw(self.name, value=data)

async def is_enabled(self) -> bool:
"""
Expand All @@ -262,8 +262,8 @@ async def is_enabled(self) -> bool:
"""
if not self.guild:
return False
return await _conf.guild(self.guild).casetypes.get_attr(self.name,
self.default_setting)
return await _conf.guild(self.guild).casetypes.get_raw(
self.name, default=self.default_setting)

async def set_enabled(self, enabled: bool):
"""
Expand All @@ -275,7 +275,7 @@ async def set_enabled(self, enabled: bool):
True if the case should be enabled, otherwise False"""
if not self.guild:
return
await _conf.guild(self.guild).casetypes.set_attr(self.name, enabled)
await _conf.guild(self.guild).casetypes.set_raw(self.name, value=enabled)

@classmethod
def from_json(cls, data: dict):
Expand Down Expand Up @@ -310,7 +310,7 @@ async def get_next_case_number(guild: discord.Guild) -> str:

"""
cases = sorted(
(await _conf.guild(guild).get_attr("cases")),
(await _conf.guild(guild).get_raw("cases")),
key=lambda x: int(x),
reverse=True
)
Expand Down Expand Up @@ -342,11 +342,12 @@ async def get_case(case_number: int, guild: discord.Guild,
If there is no case for the specified number

"""
case = await _conf.guild(guild).cases.get_attr(str(case_number))
if case is None:
try:
case = await _conf.guild(guild).cases.get_raw(str(case_number))
except KeyError as e:
raise RuntimeError(
"That case does not exist for guild {}".format(guild.name)
)
) from e
mod_channel = await get_modlog_channel(guild)
return await Case.from_json(mod_channel, bot, case)

Expand All @@ -368,7 +369,7 @@ async def get_all_cases(guild: discord.Guild, bot: Red) -> List[Case]:
A list of all cases for the guild

"""
cases = await _conf.guild(guild).get_attr("cases")
cases = await _conf.guild(guild).get_raw("cases")
case_numbers = list(cases.keys())
case_list = []
for case in case_numbers:
Expand Down Expand Up @@ -440,7 +441,7 @@ async def create_case(guild: discord.Guild, created_at: datetime, action_type: s
case_emb = await case.message_content()
msg = await mod_channel.send(embed=case_emb)
case.message = msg
await _conf.guild(guild).cases.set_attr(str(next_case_number), case.to_json())
await _conf.guild(guild).cases.set_raw(str(next_case_number), value=case.to_json())
return case


Expand All @@ -459,7 +460,7 @@ async def get_casetype(name: str, guild: discord.Guild=None) -> Union[CaseType,
-------
CaseType or None
"""
casetypes = await _conf.get_attr("casetypes")
casetypes = await _conf.get_raw("casetypes")
if name in casetypes:
data = casetypes[name]
data["name"] = name
Expand All @@ -480,7 +481,7 @@ async def get_all_casetypes(guild: discord.Guild=None) -> List[CaseType]:
A list of case types

"""
casetypes = await _conf.get_attr("casetypes")
casetypes = await _conf.get_raw("casetypes", default={})
typelist = []
for ct in casetypes.keys():
data = casetypes[ct]
Expand Down
13 changes: 6 additions & 7 deletions tests/core/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,14 +219,14 @@ async def test_set_channel_no_register(config, empty_channel):
# Dynamic attribute testing
@pytest.mark.asyncio
async def test_set_dynamic_attr(config):
await config.set_attr("foobar", True)
await config.set_raw("foobar", value=True)

assert await config.foobar() is True


@pytest.mark.asyncio
async def test_get_dynamic_attr(config):
assert await config.get_attr("foobaz", True) is True
assert await config.get_raw("foobaz", default=True) is True


# Member Group testing
Expand Down Expand Up @@ -299,13 +299,12 @@ async def test_member_clear_all(config, member_factory):


@pytest.mark.asyncio
async def test_clear_value(config_fr):
config_fr.register_global(foo=False)
await config_fr.foo.set(True)
await config_fr.foo.clear()
async def test_clear_value(config):
await config.foo.set(True)
await config.foo.clear()

with pytest.raises(KeyError):
await config_fr.get_raw('foo')
await config.get_raw('foo')


# Get All testing
Expand Down