Skip to content

Commit

Permalink
[Config] Group.__call__() has same behaviour as Group.all() (#2018)
Browse files Browse the repository at this point in the history
* Make calling groups useful

This makes config.Group.__call__ effectively an alias for Group.all(),
with the added bonus of becoming a context manager.

get_raw has been updated as well to reflect the new behaviour of
__call__.

* Fix unintended side-effects of new behaviour

* Add tests

* Add test for get_raw mixing in defaults

* Another cleanup for relying on old behaviour internally

* Fix bank relying on old behaviour

* Reformat
  • Loading branch information
Tobotimus authored Aug 26, 2018
1 parent 48a7a21 commit dbed24a
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 29 deletions.
13 changes: 6 additions & 7 deletions redbot/core/bank.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,19 +400,18 @@ async def get_account(member: Union[discord.Member, discord.User]) -> Account:
"""
if await is_global():
acc_data = (await _conf.user(member)()).copy()
default = _DEFAULT_USER.copy()
all_accounts = await _conf.all_users()
else:
acc_data = (await _conf.member(member)()).copy()
default = _DEFAULT_MEMBER.copy()
all_accounts = await _conf.all_members(member.guild)

if acc_data == {}:
acc_data = default
acc_data["name"] = member.display_name
if member.id not in all_accounts:
acc_data = {"name": member.display_name, "created_at": _DEFAULT_MEMBER["created_at"]}
try:
acc_data["balance"] = await get_default_balance(member.guild)
except AttributeError:
acc_data["balance"] = await get_default_balance()
else:
acc_data = all_accounts[member.id]

acc_data["created_at"] = _decode_time(acc_data["created_at"])
return Account(**acc_data)
Expand Down
79 changes: 57 additions & 22 deletions redbot/core/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import collections
from copy import deepcopy
from typing import Union, Tuple, TYPE_CHECKING
from typing import Any, Union, Tuple, Dict, Awaitable, AsyncContextManager, TypeVar, TYPE_CHECKING

import discord

Expand All @@ -13,8 +13,10 @@

log = logging.getLogger("red.config")

_T = TypeVar("_T")

class _ValueCtxManager:

class _ValueCtxManager(Awaitable[_T], AsyncContextManager[_T]):
"""Context manager implementation of config values.
This class allows mutable config values to be both "get" and "set" from
Expand Down Expand Up @@ -46,7 +48,7 @@ async def __aenter__(self):
)
return self.raw_value

async def __aexit__(self, *exc_info):
async def __aexit__(self, exc_type, exc, tb):
if self.raw_value != self.__original_value:
await self.value_obj.set(self.raw_value)

Expand Down Expand Up @@ -76,14 +78,14 @@ def __init__(self, identifiers: Tuple[str], default_value, driver):
def identifiers(self):
return tuple(str(i) for i in self._identifiers)

async def _get(self, default):
async def _get(self, default=...):
try:
ret = await self.driver.get(*self.identifiers)
except KeyError:
return default if default is not None else self.default
return default if default is not ... else self.default
return ret

def __call__(self, default=None):
def __call__(self, default=...) -> _ValueCtxManager[Any]:
"""Get the literal value of this data element.
Each `Value` object is created by the `Group.__getattr__` method. The
Expand Down Expand Up @@ -187,6 +189,11 @@ def __init__(
def defaults(self):
return deepcopy(self._defaults)

async def _get(self, default: Dict[str, Any] = ...) -> Dict[str, Any]:
default = default if default is not ... else self.defaults
raw = await super()._get(default)
return self.nested_update(raw, default)

# noinspection PyTypeChecker
def __getattr__(self, item: str) -> Union["Group", Value]:
"""Get an attribute of this group.
Expand Down Expand Up @@ -306,6 +313,11 @@ async def get_raw(self, *nested_path: str, default=...):
data = {"foo": {"bar": "baz"}}
d = data["foo"]["bar"]
Note
----
If retreiving a sub-group, the return value of this method will
include registered defaults for values which have not yet been set.
Parameters
----------
nested_path : str
Expand Down Expand Up @@ -339,15 +351,22 @@ async def get_raw(self, *nested_path: str, default=...):
default = poss_default

try:
return await self.driver.get(*self.identifiers, *path)
raw = await self.driver.get(*self.identifiers, *path)
except KeyError:
if default is not ...:
return default
raise
else:
if isinstance(default, dict):
return self.nested_update(raw, default)
return raw

async def all(self) -> dict:
def all(self) -> _ValueCtxManager[Dict[str, Any]]:
"""Get a dictionary representation of this group's data.
The return value of this method can also be used as an asynchronous
context manager, i.e. with :code:`async with` syntax.
Note
----
The return value of this method will include registered defaults for
Expand All @@ -359,16 +378,18 @@ async def all(self) -> dict:
All of this Group's attributes, resolved as raw data values.
"""
return self.nested_update(await self())
return self()

def nested_update(self, current, defaults=None):
def nested_update(
self, current: collections.Mapping, defaults: Dict[str, Any] = ...
) -> Dict[str, Any]:
"""Robust updater for nested dictionaries
If no defaults are passed, then the instance attribute 'defaults'
will be used.
"""
if not defaults:
if defaults is ...:
defaults = self.defaults

for key, value in current.items():
Expand Down Expand Up @@ -844,7 +865,7 @@ def custom(self, group_identifier: str, *identifiers: str):
"""
return self._get_base_group(group_identifier, *identifiers)

async def _all_from_scope(self, scope: str):
async def _all_from_scope(self, scope: str) -> Dict[int, Dict[Any, Any]]:
"""Get a dict of all values from a particular scope of data.
:code:`scope` must be one of the constants attributed to
Expand All @@ -856,12 +877,18 @@ async def _all_from_scope(self, scope: str):
overwritten.
"""
group = self._get_base_group(scope)
dict_ = await group()
ret = {}
for k, v in dict_.items():
data = group.defaults
data.update(v)
ret[int(k)] = data

try:
dict_ = await self.driver.get(*group.identifiers)
except KeyError:
pass
else:
for k, v in dict_.items():
data = group.defaults
data.update(v)
ret[int(k)] = data

return ret

async def all_guilds(self) -> dict:
Expand Down Expand Up @@ -968,13 +995,21 @@ async def all_members(self, guild: discord.Guild = None) -> dict:
ret = {}
if guild is None:
group = self._get_base_group(self.MEMBER)
dict_ = await group()
for guild_id, guild_data in dict_.items():
ret[int(guild_id)] = self._all_members_from_guild(group, guild_data)
try:
dict_ = await self.driver.get(*group.identifiers)
except KeyError:
pass
else:
for guild_id, guild_data in dict_.items():
ret[int(guild_id)] = self._all_members_from_guild(group, guild_data)
else:
group = self._get_base_group(self.MEMBER, guild.id)
guild_data = await group()
ret = self._all_members_from_guild(group, guild_data)
try:
guild_data = await self.driver.get(*group.identifiers)
except KeyError:
pass
else:
ret = self._all_members_from_guild(group, guild_data)
return ret

async def _clear_scope(self, *scopes: str):
Expand Down
36 changes: 36 additions & 0 deletions tests/core/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,3 +430,39 @@ async def test_set_then_mutate(config):
list1.append("foo")
list1 = await config.list1()
assert "foo" not in list1


@pytest.mark.asyncio
async def test_call_group_fills_defaults(config):
config.register_global(subgroup={"foo": True})
subgroup = await config.subgroup()
assert "foo" in subgroup


@pytest.mark.asyncio
async def test_group_call_ctxmgr_writes(config):
config.register_global(subgroup={"foo": True})
async with config.subgroup() as subgroup:
subgroup["bar"] = False

subgroup = await config.subgroup()
assert subgroup == {"foo": True, "bar": False}


@pytest.mark.asyncio
async def test_all_works_as_ctxmgr(config):
config.register_global(subgroup={"foo": True})
async with config.subgroup.all() as subgroup:
subgroup["bar"] = False

subgroup = await config.subgroup()
assert subgroup == {"foo": True, "bar": False}


@pytest.mark.asyncio
async def test_get_raw_mixes_defaults(config):
config.register_global(subgroup={"foo": True})
await config.subgroup.set_raw("bar", value=False)

subgroup = await config.get_raw("subgroup")
assert subgroup == {"foo": True, "bar": False}

0 comments on commit dbed24a

Please sign in to comment.