Skip to content

Commit

Permalink
Merge pull request openwallet-foundation#1833 from dbluhm/feature/imp…
Browse files Browse the repository at this point in the history
…rove-settings-types

Improve typing of settings and add plugin settings object
  • Loading branch information
swcurran authored Jun 22, 2022
2 parents 814cd8d + a7d1cbe commit 544ec3a
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 23 deletions.
10 changes: 6 additions & 4 deletions aries_cloudagent/config/argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from .error import ArgsParseError
from .util import BoundedInt, ByteSize

from .plugin_settings import PLUGIN_CONFIG_KEY

CAT_PROVISION = "general"
CAT_START = "start"
CAT_UPGRADE = "upgrade"
Expand Down Expand Up @@ -630,17 +632,17 @@ def get_settings(self, args: Namespace) -> dict:

if args.plugin_config:
with open(args.plugin_config, "r") as stream:
settings["plugin_config"] = yaml.safe_load(stream)
settings[PLUGIN_CONFIG_KEY] = yaml.safe_load(stream)

if args.plugin_config_values:
if "plugin_config" not in settings:
settings["plugin_config"] = {}
if PLUGIN_CONFIG_KEY not in settings:
settings[PLUGIN_CONFIG_KEY] = {}

for value_str in chain(*args.plugin_config_values):
key, value = value_str.split("=", maxsplit=1)
value = yaml.safe_load(value)
deepmerge.always_merger.merge(
settings["plugin_config"],
settings[PLUGIN_CONFIG_KEY],
reduce(lambda v, k: {k: v}, key.split(".")[::-1], value),
)

Expand Down
23 changes: 13 additions & 10 deletions aries_cloudagent/config/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Configuration base classes."""

from abc import ABC, abstractmethod
from typing import Mapping, Optional, Type, TypeVar
from typing import Any, Iterator, Mapping, Optional, Type, TypeVar

from ..core.error import BaseError

Expand All @@ -16,11 +16,11 @@ class SettingsError(ConfigError):
"""The base exception raised by `BaseSettings` implementations."""


class BaseSettings(Mapping[str, object]):
class BaseSettings(Mapping[str, Any]):
"""Base settings class."""

@abstractmethod
def get_value(self, *var_names, default=None):
def get_value(self, *var_names, default: Optional[Any] = None) -> Any:
"""Fetch a setting.
Args:
Expand All @@ -32,7 +32,7 @@ def get_value(self, *var_names, default=None):
"""

def get_bool(self, *var_names, default=None) -> bool:
def get_bool(self, *var_names, default: Optional[bool] = None) -> Optional[bool]:
"""Fetch a setting as a boolean value.
Args:
Expand All @@ -42,9 +42,10 @@ def get_bool(self, *var_names, default=None) -> bool:
value = self.get_value(*var_names, default)
if value is not None:
value = bool(value and value not in ("false", "False", "0"))

return value

def get_int(self, *var_names, default=None) -> int:
def get_int(self, *var_names, default: Optional[int] = None) -> Optional[int]:
"""Fetch a setting as an integer value.
Args:
Expand All @@ -54,9 +55,10 @@ def get_int(self, *var_names, default=None) -> int:
value = self.get_value(*var_names, default)
if value is not None:
value = int(value)

return value

def get_str(self, *var_names, default=None) -> str:
def get_str(self, *var_names, default: Optional[str] = None) -> Optional[str]:
"""Fetch a setting as a string value.
Args:
Expand All @@ -66,10 +68,11 @@ def get_str(self, *var_names, default=None) -> str:
value = self.get_value(*var_names, default=default)
if value is not None:
value = str(value)

return value

@abstractmethod
def __iter__(self):
def __iter__(self) -> Iterator:
"""Iterate settings keys."""

def __getitem__(self, index):
Expand All @@ -91,7 +94,7 @@ def copy(self) -> "BaseSettings":
"""Produce a copy of the settings instance."""

@abstractmethod
def extend(self, other: Mapping[str, object]) -> "BaseSettings":
def extend(self, other: Mapping[str, Any]) -> "BaseSettings":
"""Merge another mapping to produce a new settings instance."""

def __repr__(self) -> str:
Expand All @@ -111,7 +114,7 @@ class BaseInjector(ABC):
def inject(
self,
base_cls: Type[InjectType],
settings: Mapping[str, object] = None,
settings: Optional[Mapping[str, Any]] = None,
) -> InjectType:
"""
Get the provided instance of a given class identifier.
Expand All @@ -129,7 +132,7 @@ def inject(
def inject_or(
self,
base_cls: Type[InjectType],
settings: Mapping[str, object] = None,
settings: Optional[Mapping[str, Any]] = None,
default: Optional[InjectType] = None,
) -> Optional[InjectType]:
"""
Expand Down
4 changes: 2 additions & 2 deletions aries_cloudagent/config/base_context.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Base injection context builder classes."""

from abc import ABC, abstractmethod
from typing import Mapping
from typing import Any, Mapping, Optional

from .injection_context import InjectionContext
from .settings import Settings
Expand All @@ -10,7 +10,7 @@
class ContextBuilder(ABC):
"""Base injection context builder class."""

def __init__(self, settings: Mapping[str, object] = None):
def __init__(self, settings: Optional[Mapping[str, Any]] = None):
"""
Initialize an instance of the context builder.
Expand Down
81 changes: 81 additions & 0 deletions aries_cloudagent/config/plugin_settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""Settings implementation for plugins."""

from typing import Any, Mapping, Optional

from .base import BaseSettings


PLUGIN_CONFIG_KEY = "plugin_config"


class PluginSettings(BaseSettings):
"""Retrieve immutable settings for plugins.
Plugin settings should be retrieved by calling:
PluginSettings.for_plugin(settings, "my_plugin", {"default": "values"})
This will extract the PLUGIN_CONFIG_KEY in "settings" and return a new
PluginSettings instance.
"""

def __init__(self, values: Optional[Mapping[str, Any]] = None):
"""Initialize a Settings object.
Args:
values: An optional dictionary of settings
"""
self._values = {}
if values:
self._values.update(values)

def __contains__(self, index):
"""Define 'in' operator."""
return index in self._values

def __iter__(self):
"""Iterate settings keys."""
return iter(self._values)

def __len__(self):
"""Fetch the length of the mapping."""
return len(self._values)

def __bool__(self):
"""Convert settings to a boolean."""
return True

def copy(self) -> BaseSettings:
"""Produce a copy of the settings instance."""
return PluginSettings(self._values)

def extend(self, other: Mapping[str, Any]) -> BaseSettings:
"""Merge another settings instance to produce a new instance."""
vals = self._values.copy()
vals.update(other)
return PluginSettings(vals)

def get_value(self, *var_names: str, default: Any = None):
"""Fetch a setting.
Args:
var_names: A list of variable name alternatives
default: The default value to return if none are defined
"""
for k in var_names:
if k in self._values:
return self._values[k]
return default

@classmethod
def for_plugin(
cls,
settings: BaseSettings,
plugin: str,
default: Optional[Mapping[str, Any]] = None,
) -> "PluginSettings":
"""Construct a PluginSettings object from another settings object.
PLUGIN_CONFIG_KEY is read from settings.
"""
return cls(settings.get(PLUGIN_CONFIG_KEY, {}).get(plugin, default))
15 changes: 10 additions & 5 deletions aries_cloudagent/config/settings.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
"""Settings implementation."""

from typing import Mapping
from typing import Any, Mapping, MutableMapping, Optional

from .base import BaseSettings
from .plugin_settings import PluginSettings


class Settings(BaseSettings):
class Settings(BaseSettings, MutableMapping[str, Any]):
"""Mutable settings implementation."""

def __init__(self, values: Mapping[str, object] = None):
def __init__(self, values: Optional[Mapping[str, Any]] = None):
"""Initialize a Settings object.
Args:
Expand Down Expand Up @@ -90,12 +91,16 @@ def copy(self) -> BaseSettings:
"""Produce a copy of the settings instance."""
return Settings(self._values)

def extend(self, other: Mapping[str, object]) -> BaseSettings:
def extend(self, other: Mapping[str, Any]) -> BaseSettings:
"""Merge another settings instance to produce a new instance."""
vals = self._values.copy()
vals.update(other)
return Settings(vals)

def update(self, other: Mapping[str, object]):
def update(self, other: Mapping[str, Any]):
"""Update the settings in place."""
self._values.update(other)

def for_plugin(self, plugin: str, default: Optional[Mapping[str, Any]] = None):
"""Retrieve settings for plugin."""
return PluginSettings.for_plugin(self, plugin, default)
25 changes: 25 additions & 0 deletions aries_cloudagent/config/tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@

from unittest import TestCase

from aries_cloudagent.config.plugin_settings import PluginSettings

from ..base import SettingsError
from ..settings import Settings
from ..plugin_settings import PLUGIN_CONFIG_KEY


class TestSettings(TestCase):
Expand Down Expand Up @@ -59,3 +62,25 @@ def test_set_default(self):
assert self.test_instance[self.test_key] == self.test_value
self.test_instance.set_default("BOOL", "True")
assert self.test_instance["BOOL"] == "True"

def test_plugin_setting_retrieval(self):
plugin_setting_values = {
"value0": 0,
"value1": 1,
"value2": 2,
"value3": 3,
"value4": 4,
}
self.test_instance[PLUGIN_CONFIG_KEY] = {"my_plugin": plugin_setting_values}

plugin_settings = self.test_instance.for_plugin("my_plugin")
assert isinstance(plugin_settings, PluginSettings)
assert plugin_settings._values == plugin_setting_values
for key in plugin_setting_values:
assert key in plugin_settings
assert plugin_settings[key] == plugin_setting_values[key]
assert plugin_settings.get_value(key) == plugin_setting_values[key]
with self.assertRaises(KeyError):
plugin_settings["MISSING"]
assert len(plugin_settings) == 5
assert len(plugin_settings) == 5
5 changes: 3 additions & 2 deletions aries_cloudagent/transport/inbound/ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import logging
from typing import Optional

from aiohttp import WSMessage, WSMsgType, web

Expand Down Expand Up @@ -30,10 +31,10 @@ def __init__(self, host: str, port: int, create_session, **kwargs) -> None:
self.host = host
self.port = port
self.site: web.BaseSite = None
self.heartbeat_interval: int = self.root_profile.settings.get_int(
self.heartbeat_interval: Optional[int] = self.root_profile.settings.get_int(
"transport.ws.heartbeat_interval"
)
self.timout_interval: int = self.root_profile.settings.get_int(
self.timout_interval: Optional[int] = self.root_profile.settings.get_int(
"transport.ws.timout_interval"
)

Expand Down

0 comments on commit 544ec3a

Please sign in to comment.