Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve typing of settings and add plugin settings object #1833

Merged
Merged
10 changes: 6 additions & 4 deletions aries_cloudagent/config/argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from .error import ArgsParseError
from .util import BoundedInt, ByteSize

from .plugin_settings import PLUGIN_CONFIG_KEY

CAT_PROVISION = "general"
CAT_START = "start"
CAT_UPGRADE = "upgrade"
Expand Down Expand Up @@ -630,17 +632,17 @@ def get_settings(self, args: Namespace) -> dict:

if args.plugin_config:
with open(args.plugin_config, "r") as stream:
settings["plugin_config"] = yaml.safe_load(stream)
settings[PLUGIN_CONFIG_KEY] = yaml.safe_load(stream)

if args.plugin_config_values:
if "plugin_config" not in settings:
settings["plugin_config"] = {}
if PLUGIN_CONFIG_KEY not in settings:
settings[PLUGIN_CONFIG_KEY] = {}

for value_str in chain(*args.plugin_config_values):
key, value = value_str.split("=", maxsplit=1)
value = yaml.safe_load(value)
deepmerge.always_merger.merge(
settings["plugin_config"],
settings[PLUGIN_CONFIG_KEY],
reduce(lambda v, k: {k: v}, key.split(".")[::-1], value),
)

Expand Down
23 changes: 13 additions & 10 deletions aries_cloudagent/config/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Configuration base classes."""

from abc import ABC, abstractmethod
from typing import Mapping, Optional, Type, TypeVar
from typing import Any, Iterator, Mapping, Optional, Type, TypeVar

from ..core.error import BaseError

Expand All @@ -16,11 +16,11 @@ class SettingsError(ConfigError):
"""The base exception raised by `BaseSettings` implementations."""


class BaseSettings(Mapping[str, object]):
class BaseSettings(Mapping[str, Any]):
"""Base settings class."""

@abstractmethod
def get_value(self, *var_names, default=None):
def get_value(self, *var_names, default: Optional[Any] = None) -> Any:
"""Fetch a setting.

Args:
Expand All @@ -32,7 +32,7 @@ def get_value(self, *var_names, default=None):

"""

def get_bool(self, *var_names, default=None) -> bool:
def get_bool(self, *var_names, default: Optional[bool] = None) -> Optional[bool]:
"""Fetch a setting as a boolean value.

Args:
Expand All @@ -42,9 +42,10 @@ def get_bool(self, *var_names, default=None) -> bool:
value = self.get_value(*var_names, default)
if value is not None:
value = bool(value and value not in ("false", "False", "0"))

return value

def get_int(self, *var_names, default=None) -> int:
def get_int(self, *var_names, default: Optional[int] = None) -> Optional[int]:
"""Fetch a setting as an integer value.

Args:
Expand All @@ -54,9 +55,10 @@ def get_int(self, *var_names, default=None) -> int:
value = self.get_value(*var_names, default)
if value is not None:
value = int(value)

return value

def get_str(self, *var_names, default=None) -> str:
def get_str(self, *var_names, default: Optional[str] = None) -> Optional[str]:
"""Fetch a setting as a string value.

Args:
Expand All @@ -66,10 +68,11 @@ def get_str(self, *var_names, default=None) -> str:
value = self.get_value(*var_names, default=default)
if value is not None:
value = str(value)

return value

@abstractmethod
def __iter__(self):
def __iter__(self) -> Iterator:
"""Iterate settings keys."""

def __getitem__(self, index):
Expand All @@ -91,7 +94,7 @@ def copy(self) -> "BaseSettings":
"""Produce a copy of the settings instance."""

@abstractmethod
def extend(self, other: Mapping[str, object]) -> "BaseSettings":
def extend(self, other: Mapping[str, Any]) -> "BaseSettings":
"""Merge another mapping to produce a new settings instance."""

def __repr__(self) -> str:
Expand All @@ -111,7 +114,7 @@ class BaseInjector(ABC):
def inject(
self,
base_cls: Type[InjectType],
settings: Mapping[str, object] = None,
settings: Optional[Mapping[str, Any]] = None,
) -> InjectType:
"""
Get the provided instance of a given class identifier.
Expand All @@ -129,7 +132,7 @@ def inject(
def inject_or(
self,
base_cls: Type[InjectType],
settings: Mapping[str, object] = None,
settings: Optional[Mapping[str, Any]] = None,
default: Optional[InjectType] = None,
) -> Optional[InjectType]:
"""
Expand Down
4 changes: 2 additions & 2 deletions aries_cloudagent/config/base_context.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Base injection context builder classes."""

from abc import ABC, abstractmethod
from typing import Mapping
from typing import Any, Mapping, Optional

from .injection_context import InjectionContext
from .settings import Settings
Expand All @@ -10,7 +10,7 @@
class ContextBuilder(ABC):
"""Base injection context builder class."""

def __init__(self, settings: Mapping[str, object] = None):
def __init__(self, settings: Optional[Mapping[str, Any]] = None):
"""
Initialize an instance of the context builder.

Expand Down
81 changes: 81 additions & 0 deletions aries_cloudagent/config/plugin_settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""Settings implementation for plugins."""

from typing import Any, Mapping, Optional

from .base import BaseSettings


PLUGIN_CONFIG_KEY = "plugin_config"


class PluginSettings(BaseSettings):
"""Retrieve immutable settings for plugins.

Plugin settings should be retrieved by calling:

PluginSettings.for_plugin(settings, "my_plugin", {"default": "values"})

This will extract the PLUGIN_CONFIG_KEY in "settings" and return a new
PluginSettings instance.
"""

def __init__(self, values: Optional[Mapping[str, Any]] = None):
"""Initialize a Settings object.

Args:
values: An optional dictionary of settings
"""
self._values = {}
if values:
self._values.update(values)

def __contains__(self, index):
"""Define 'in' operator."""
return index in self._values

def __iter__(self):
"""Iterate settings keys."""
return iter(self._values)

def __len__(self):
"""Fetch the length of the mapping."""
return len(self._values)

def __bool__(self):
"""Convert settings to a boolean."""
return True

def copy(self) -> BaseSettings:
"""Produce a copy of the settings instance."""
return PluginSettings(self._values)

def extend(self, other: Mapping[str, Any]) -> BaseSettings:
"""Merge another settings instance to produce a new instance."""
vals = self._values.copy()
vals.update(other)
return PluginSettings(vals)

def get_value(self, *var_names: str, default: Any = None):
"""Fetch a setting.

Args:
var_names: A list of variable name alternatives
default: The default value to return if none are defined
"""
for k in var_names:
if k in self._values:
return self._values[k]
return default

@classmethod
def for_plugin(
cls,
settings: BaseSettings,
plugin: str,
default: Optional[Mapping[str, Any]] = None,
) -> "PluginSettings":
"""Construct a PluginSettings object from another settings object.

PLUGIN_CONFIG_KEY is read from settings.
"""
return cls(settings.get(PLUGIN_CONFIG_KEY, {}).get(plugin, default))
15 changes: 10 additions & 5 deletions aries_cloudagent/config/settings.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
"""Settings implementation."""

from typing import Mapping
from typing import Any, Mapping, MutableMapping, Optional

from .base import BaseSettings
from .plugin_settings import PluginSettings


class Settings(BaseSettings):
class Settings(BaseSettings, MutableMapping[str, Any]):
"""Mutable settings implementation."""

def __init__(self, values: Mapping[str, object] = None):
def __init__(self, values: Optional[Mapping[str, Any]] = None):
"""Initialize a Settings object.

Args:
Expand Down Expand Up @@ -90,12 +91,16 @@ def copy(self) -> BaseSettings:
"""Produce a copy of the settings instance."""
return Settings(self._values)

def extend(self, other: Mapping[str, object]) -> BaseSettings:
def extend(self, other: Mapping[str, Any]) -> BaseSettings:
"""Merge another settings instance to produce a new instance."""
vals = self._values.copy()
vals.update(other)
return Settings(vals)

def update(self, other: Mapping[str, object]):
def update(self, other: Mapping[str, Any]):
"""Update the settings in place."""
self._values.update(other)

def for_plugin(self, plugin: str, default: Optional[Mapping[str, Any]] = None):
"""Retrieve settings for plugin."""
return PluginSettings.for_plugin(self, plugin, default)
25 changes: 25 additions & 0 deletions aries_cloudagent/config/tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@

from unittest import TestCase

from aries_cloudagent.config.plugin_settings import PluginSettings

from ..base import SettingsError
from ..settings import Settings
from ..plugin_settings import PLUGIN_CONFIG_KEY


class TestSettings(TestCase):
Expand Down Expand Up @@ -59,3 +62,25 @@ def test_set_default(self):
assert self.test_instance[self.test_key] == self.test_value
self.test_instance.set_default("BOOL", "True")
assert self.test_instance["BOOL"] == "True"

def test_plugin_setting_retrieval(self):
plugin_setting_values = {
"value0": 0,
"value1": 1,
"value2": 2,
"value3": 3,
"value4": 4,
}
self.test_instance[PLUGIN_CONFIG_KEY] = {"my_plugin": plugin_setting_values}

plugin_settings = self.test_instance.for_plugin("my_plugin")
assert isinstance(plugin_settings, PluginSettings)
assert plugin_settings._values == plugin_setting_values
for key in plugin_setting_values:
assert key in plugin_settings
assert plugin_settings[key] == plugin_setting_values[key]
assert plugin_settings.get_value(key) == plugin_setting_values[key]
with self.assertRaises(KeyError):
plugin_settings["MISSING"]
assert len(plugin_settings) == 5
assert len(plugin_settings) == 5
5 changes: 3 additions & 2 deletions aries_cloudagent/transport/inbound/ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import logging
from typing import Optional

from aiohttp import WSMessage, WSMsgType, web

Expand Down Expand Up @@ -30,10 +31,10 @@ def __init__(self, host: str, port: int, create_session, **kwargs) -> None:
self.host = host
self.port = port
self.site: web.BaseSite = None
self.heartbeat_interval: int = self.root_profile.settings.get_int(
self.heartbeat_interval: Optional[int] = self.root_profile.settings.get_int(
"transport.ws.heartbeat_interval"
)
self.timout_interval: int = self.root_profile.settings.get_int(
self.timout_interval: Optional[int] = self.root_profile.settings.get_int(
"transport.ws.timout_interval"
)

Expand Down