diff --git a/tests/config.json b/tests/config.json index 43d78af..69f0c6f 100644 --- a/tests/config.json +++ b/tests/config.json @@ -1,5 +1,10 @@ { "arg1": "stuff", "opt1": "things", - "opt2": "nothing" + "opt2": "nothing", + "simple_app": { + "arg1": "stuff2", + "opt1": "things2", + "opt2": "nothing2" + } } \ No newline at end of file diff --git a/tests/config.toml b/tests/config.toml index 9c43d5c..0e83238 100644 --- a/tests/config.toml +++ b/tests/config.toml @@ -1,3 +1,8 @@ arg1 = "stuff" opt1 = "things" -opt2 = "nothing" \ No newline at end of file +opt2 = "nothing" + +[simple_app] +arg1 = "stuff2" +opt1 = "things2" +opt2 = "nothing2" \ No newline at end of file diff --git a/tests/config.yml b/tests/config.yml index 9448cf3..d8362d6 100644 --- a/tests/config.yml +++ b/tests/config.yml @@ -1,3 +1,8 @@ arg1: stuff opt1: things -opt2: nothing \ No newline at end of file +opt2: nothing + +simple_app: + arg1: stuff2 + opt1: things2 + opt2: nothing2 \ No newline at end of file diff --git a/tests/test_example.py b/tests/test_example.py index bee6f63..b28ce38 100644 --- a/tests/test_example.py +++ b/tests/test_example.py @@ -217,6 +217,46 @@ def test_simple_example_decorated_default(simple_app_decorated, confs): assert "No such file" in result.stdout, f"Wrong error message for {conf}" +@pytest.mark.parametrize("confs", CONFS, ids=str) +def test_simple_example_decorated_section(simple_app_decorated, confs): + """Test Simple YAML app (decorator).""" + + conf, _, dec = confs + + # skip tests that won't work + if conf.split(".")[-1] in ["ini", "env"]: + return + + _app = simple_app_decorated(dec, section=["simple_app"]) + + result = RUNNER.invoke(_app, ["--help"]) + assert ( + result.exit_code == 0 + ), f"Couldn't get to `--help` for {conf}\n\n{result.stdout}" + + result = RUNNER.invoke(_app, ["--config", conf]) + assert result.exit_code == 0, f"Loading failed for {conf}\n\n{result.stdout}" + assert ( + result.stdout.strip() == "things2 nothing2 stuff2" + ), f"Unexpected output for {conf}" + + result = RUNNER.invoke(_app, ["--config", conf, "others2"]) + assert result.exit_code == 0, f"Loading failed for {conf}\n\n{result.stdout}" + assert ( + result.stdout.strip() == "things2 nothing2 others2" + ), f"Unexpected output for {conf}" + + result = RUNNER.invoke(_app, ["--config", conf, "--opt1", "people2"]) + assert result.exit_code == 0, f"Loading failed for {conf}\n\n{result.stdout}" + assert ( + result.stdout.strip() == "people2 nothing2 stuff2" + ), f"Unexpected output for {conf}" + + result = RUNNER.invoke(_app, ["--config", conf + ".non_existent"]) + assert result.exit_code != 0, f"Should have failed for {conf}\n\n{result.stdout}" + assert "No such file" in result.stdout, f"Wrong error message for {conf}" + + def test_pyproject_example(simple_app): """Test pyproject example.""" diff --git a/typer_config/decorators.py b/typer_config/decorators.py index 69eab02..b79156e 100644 --- a/typer_config/decorators.py +++ b/typer_config/decorators.py @@ -9,13 +9,7 @@ from typer import Option -from .callbacks import ( - conf_callback_factory, - dotenv_conf_callback, - json_conf_callback, - toml_conf_callback, - yaml_conf_callback, -) +from .callbacks import conf_callback_factory from .dumpers import json_dumper, toml_dumper, yaml_dumper from .loaders import ( dotenv_loader, @@ -25,10 +19,10 @@ toml_loader, yaml_loader, ) +from .utils import get_dict_section if TYPE_CHECKING: # pragma: no cover from .__typing import ( - ConfigDict, ConfigDumper, ConfigParameterCallback, FilePath, @@ -106,6 +100,7 @@ def wrapped(*args, **kwargs): # noqa: ANN202,ANN002,ANN003 # default decorators def use_json_config( + section: Optional[List[str]] = None, param_name: TyperParameterName = "config", param_help: str = "Configuration file.", default_value: Optional[TyperParameterValue] = None, @@ -126,6 +121,8 @@ def main(...): ``` Args: + section (List[str], optional): List of nested sections to access in the config. + Defaults to None. param_name (TyperParameterName, optional): name of config parameter. Defaults to "config". param_help (str, optional): config parameter help string. @@ -137,23 +134,24 @@ def main(...): TyperCommandDecorator: decorator to apply to command """ - if default_value is not None: - callback = conf_callback_factory( - loader_transformer( - json_loader, - loader_conditional=lambda param_value: param_value, - param_transformer=lambda param_value: param_value - if param_value - else default_value, + callback = conf_callback_factory( + loader_transformer( + json_loader, + loader_conditional=lambda param_value: param_value, + param_transformer=( + lambda param_value: param_value if param_value else default_value ) + if default_value is not None + else None, + config_transformer=lambda config: get_dict_section(config, section), ) - else: - callback = json_conf_callback + ) return use_config(callback=callback, param_name=param_name, param_help=param_help) def use_yaml_config( + section: Optional[List[str]] = None, param_name: TyperParameterName = "config", param_help: str = "Configuration file.", default_value: Optional[TyperParameterValue] = None, @@ -174,6 +172,8 @@ def main(...): ``` Args: + section (List[str], optional): List of nested sections to access in the config. + Defaults to None. param_name (str, optional): name of config parameter. Defaults to "config". param_help (str, optional): config parameter help string. Defaults to "Configuration file.". @@ -184,23 +184,24 @@ def main(...): TyperCommandDecorator: decorator to apply to command """ - if default_value is not None: - callback = conf_callback_factory( - loader_transformer( - yaml_loader, - loader_conditional=lambda param_value: param_value, - param_transformer=lambda param_value: param_value - if param_value - else default_value, + callback = conf_callback_factory( + loader_transformer( + yaml_loader, + loader_conditional=lambda param_value: param_value, + param_transformer=( + lambda param_value: param_value if param_value else default_value ) + if default_value is not None + else None, + config_transformer=lambda config: get_dict_section(config, section), ) - else: - callback = yaml_conf_callback + ) return use_config(callback=callback, param_name=param_name, param_help=param_help) def use_toml_config( + section: Optional[List[str]] = None, param_name: TyperParameterName = "config", param_help: str = "Configuration file.", default_value: Optional[TyperParameterValue] = None, @@ -221,6 +222,8 @@ def main(...): ``` Args: + section (List[str], optional): List of nested sections to access in the config. + Defaults to None. param_name (str, optional): name of config parameter. Defaults to "config". param_help (str, optional): config parameter help string. Defaults to "Configuration file.". @@ -231,23 +234,24 @@ def main(...): TyperCommandDecorator: decorator to apply to command """ - if default_value is not None: - callback = conf_callback_factory( - loader_transformer( - toml_loader, - loader_conditional=lambda param_value: param_value, - param_transformer=lambda param_value: param_value - if param_value - else default_value, + callback = conf_callback_factory( + loader_transformer( + toml_loader, + loader_conditional=lambda param_value: param_value, + param_transformer=( + lambda param_value: param_value if param_value else default_value ) + if default_value is not None + else None, + config_transformer=lambda config: get_dict_section(config, section), ) - else: - callback = toml_conf_callback + ) return use_config(callback=callback, param_name=param_name, param_help=param_help) def use_dotenv_config( + section: Optional[List[str]] = None, param_name: TyperParameterName = "config", param_help: str = "Configuration file.", default_value: Optional[TyperParameterValue] = None, @@ -268,6 +272,8 @@ def main(...): ``` Args: + section (List[str], optional): List of nested sections to access in the config. + Defaults to None. param_name (str, optional): name of config parameter. Defaults to "config". param_help (str, optional): config parameter help string. Defaults to "Configuration file.". @@ -278,18 +284,18 @@ def main(...): TyperCommandDecorator: decorator to apply to command """ - if default_value is not None: - callback = conf_callback_factory( - loader_transformer( - dotenv_loader, - loader_conditional=lambda param_value: param_value, - param_transformer=lambda param_value: param_value - if param_value - else default_value, + callback = conf_callback_factory( + loader_transformer( + dotenv_loader, + loader_conditional=lambda param_value: param_value, + param_transformer=( + lambda param_value: param_value if param_value else default_value ) + if default_value is not None + else None, + config_transformer=lambda config: get_dict_section(config, section), ) - else: - callback = dotenv_conf_callback + ) return use_config(callback=callback, param_name=param_name, param_help=param_help) @@ -327,31 +333,18 @@ def main(...): TyperCommandDecorator: decorator to apply to command """ - def _get_section(_section: List[str], config: ConfigDict) -> ConfigDict: - for sect in _section: - config = config.get(sect, {}) - - return config - - if default_value is not None: - callback = conf_callback_factory( - loader_transformer( - ini_loader, - loader_conditional=lambda param_value: param_value, - param_transformer=lambda param_value: param_value - if param_value - else default_value, - config_transformer=lambda config: _get_section(section, config), - ) - ) - else: - callback = conf_callback_factory( - loader_transformer( - ini_loader, - loader_conditional=lambda param_value: param_value, - config_transformer=lambda config: _get_section(section, config), + callback = conf_callback_factory( + loader_transformer( + ini_loader, + loader_conditional=lambda param_value: param_value, + param_transformer=( + lambda param_value: param_value if param_value else default_value ) + if default_value is not None + else None, + config_transformer=lambda config: get_dict_section(config, section), ) + ) return use_config(callback=callback, param_name=param_name, param_help=param_help) diff --git a/typer_config/utils.py b/typer_config/utils.py new file mode 100644 index 0000000..e8abe46 --- /dev/null +++ b/typer_config/utils.py @@ -0,0 +1,23 @@ +"""Utilities.""" +from __future__ import annotations + +from typing import Any, Dict, List, Optional + + +def get_dict_section( + _dict: Dict[Any, Any], keys: Optional[List[Any]] = None +) -> Dict[Any, Any]: + """Get section of a dictionary. + + Args: + _dict (Dict[str, Any]): dictionary to access + keys (List[str]): list of keys to successively access in the dictionary + + Returns: + Dict[str, Any]: section of dictionary requested + """ + if keys is not None: + for key in keys: + _dict = _dict.get(key, {}) + + return _dict