Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce a sentinel value _NO_VALUE to improve Global resolvers to support defaults 0 or None #2976

Merged
merged 20 commits into from
Aug 30, 2023
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 16 additions & 17 deletions kedro/config/omegaconf_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

_config_logger = logging.getLogger(__name__)

_NO_VALUE = object()


class OmegaConfigLoader(AbstractConfigLoader):
"""Recursively scan directories (config paths) contained in ``conf_source`` for
Expand Down Expand Up @@ -316,31 +318,28 @@ def _register_globals_resolver(self):
"""Register the globals resolver"""
OmegaConf.register_new_resolver(
"globals",
lambda variable, default_value=None: self._get_globals_value(
variable, default_value
),
self._get_globals_value,
replace=True,
)

def _get_globals_value(self, variable, default_value):
def _get_globals_value(self, variable, default_value=_NO_VALUE):
"""Return the globals values to the resolver"""
if variable.startswith("_"):
raise InterpolationResolutionError(
"Keys starting with '_' are not supported for globals."
)
keys = variable.split(".")
value = self["globals"]
for k in keys:
value = value.get(k)
if not value:
if default_value:
_config_logger.debug(
f"Using the default value for the global variable {variable}."
)
return default_value
msg = f"Globals key '{variable}' not found and no default value provided. "
raise InterpolationResolutionError(msg)
return value
global_omegaconf = OmegaConf.create(self["globals"])
interpolated_value = OmegaConf.select(
global_omegaconf, variable, default=default_value
)
if interpolated_value != _NO_VALUE:
return interpolated_value
else:
raise InterpolationResolutionError(
"Default value is not defined for {$globals: {keys}}.".replace(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this error also thrown if the key isn't found? In that case the message should probably be as before: "Globals key '{variable}' not found and no default value provided. "

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Currently there are 3 paths it can take:

  1. exists in globals.yml (take the global value as is)
  2. Not define in globals.yml, but default provided
  3. Not define in globals.yml and not default provided - raise InterpolationResolutionError

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

"{keys}", variable
)
)

@staticmethod
def _register_new_resolvers(resolvers: dict[str, Callable]):
Expand Down
55 changes: 48 additions & 7 deletions tests/config/test_omegaconf_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,7 @@ def test_custom_resolvers(self, tmp_path):
def test_globals(self, tmp_path):
globals_params = tmp_path / _BASE_ENV / "globals.yml"
globals_config = {
"x": 34,
"x": 0,
}
_write_yaml(globals_params, globals_config)
conf = OmegaConfigLoader(tmp_path, default_run_env="")
Expand Down Expand Up @@ -704,7 +704,6 @@ def test_globals_resolution(self, tmp_path):
_write_yaml(globals_params, globals_config)
_write_yaml(base_catalog, catalog_config)
conf = OmegaConfigLoader(tmp_path, default_run_env="")
assert OmegaConf.has_resolver("globals")
# Globals are resolved correctly in parameter
assert conf["parameters"]["my_param"] == globals_config["x"]
# The default value is used if the key does not exist
Expand Down Expand Up @@ -760,25 +759,67 @@ def test_globals_across_env(self, tmp_path):
# Base global value is accessible to local params
assert conf["parameters"]["param2"] == base_globals_config["x"]

def test_bad_globals(self, tmp_path):
def test_globals_default(self, tmp_path):
base_params = tmp_path / _BASE_ENV / "parameters.yml"
base_globals = tmp_path / _BASE_ENV / "globals.yml"
base_param_config = {
"int": "${globals:x.NOT_EXIST, 1}",
"str": "${globals: x.NOT_EXIST, '2'}",
"dummy": "${globals: x.DUMMY.DUMMY, '2'}",
}
base_globals_config = {"x": {"DUMMY": 3}}
_write_yaml(base_params, base_param_config)
_write_yaml(base_globals, base_globals_config)
conf = OmegaConfigLoader(tmp_path, default_run_env="")
# Default value is being used as int
assert conf["parameters"]["int"] == 1
# Default value is being used as str
assert conf["parameters"]["str"] == "2"
# Test when x.DUMMY is not a dictionary it should still work
assert conf["parameters"]["dummy"] == "2"

def test_globals_default_none(self, tmp_path):
base_params = tmp_path / _BASE_ENV / "parameters.yml"
base_globals = tmp_path / _BASE_ENV / "globals.yml"
base_param_config = {
"param1": "${globals:x.y}",
"zero": "${globals: x.NOT_EXIST, 0}",
"null": "${globals: x.NOT_EXIST, null}",
"null2": "${globals: x.y}",
}
base_globals_config = {
"x": {
"z": 23,
"y": None,
},
}
_write_yaml(base_params, base_param_config)
_write_yaml(base_globals, base_globals_config)
conf = OmegaConfigLoader(tmp_path, default_run_env="")
# Default value can be 0 or null
assert conf["parameters"]["zero"] == 0
assert conf["parameters"]["null"] is None
# Global value is null
assert conf["parameters"]["null2"] is None

def test_globals_missing_default(self, tmp_path):
base_params = tmp_path / _BASE_ENV / "parameters.yml"
globals_params = tmp_path / _BASE_ENV / "globals.yml"
param_config = {
"NOT_OK": "${globals:nested.NOT_EXIST}",
}
globals_config = {
"nested": {
"y": 42,
},
}
_write_yaml(base_params, param_config)
_write_yaml(globals_params, globals_config)
conf = OmegaConfigLoader(tmp_path, default_run_env="")

with pytest.raises(
InterpolationResolutionError,
match=r"Globals key 'x.y' not found and no default value provided.",
InterpolationResolutionError, match="Default value is not defined for"
):
conf["parameters"]["param1"]
conf["parameters"]["NOT_OK"]

def test_bad_globals_underscore(self, tmp_path):
base_params = tmp_path / _BASE_ENV / "parameters.yml"
Expand Down