Skip to content

Commit

Permalink
Modify KedroContext with frozen attributes instead of frozen class (#…
Browse files Browse the repository at this point in the history
…3300)

* modify release note

Signed-off-by: Nok <[email protected]>

* Make KedroContext partial frozen

Signed-off-by: Nok <[email protected]>

* fix KedroContext and add test for public and internal attributes

Signed-off-by: Nok <[email protected]>

* Fix tests

Signed-off-by: Nok <[email protected]>

* Rearrange the `KedroContext` arguments

Signed-off-by: Nok <[email protected]>

* improve test_set_new_attribute

Signed-off-by: Nok <[email protected]>

* Unfreeze the Kedrocontext attributes

Signed-off-by: Nok <[email protected]>

* clean up

Signed-off-by: Nok <[email protected]>

* remove redundant test

Signed-off-by: Nok <[email protected]>

---------

Signed-off-by: Nok <[email protected]>
Signed-off-by: Nok Lam Chan <[email protected]>
  • Loading branch information
noklam authored Nov 21, 2023
1 parent c1cf255 commit f285f2c
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 19 deletions.
2 changes: 1 addition & 1 deletion RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ We are grateful to the following for submitting PRs that contributed to this rel

## Bug fixes and other changes
* Removed fatal error from being logged when a Kedro session is created in a directory without git.
* `KedroContext` is now an `attrs`'s frozen class and `config_loader` is available as public attribute.
* `KedroContext` is now a attr's dataclass, `config_loader` is available as public attribute.
* Fixed `CONFIG_LOADER_CLASS` validation so that `TemplatedConfigLoader` can be specified in settings.py. Any `CONFIG_LOADER_CLASS` must be a subclass of `AbstractConfigLoader`.
* Added runner name to the `run_params` dictionary used in pipeline hooks.
* Updated [Databricks documentation](https://docs.kedro.org/en/0.18.1/deployment/databricks.html) to include how to get it working with IPython extension and Kedro-Viz.
Expand Down
36 changes: 28 additions & 8 deletions kedro/framework/context/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from urllib.parse import urlparse
from warnings import warn

from attrs import field, frozen
from attrs import define, field
from omegaconf import OmegaConf
from pluggy import PluginManager

Expand Down Expand Up @@ -142,18 +142,38 @@ def _expand_full_path(project_path: str | Path) -> Path:
return Path(project_path).expanduser().resolve()


@frozen
@define(slots=False) # Enable setting new attributes to `KedroContext`
class KedroContext:
"""``KedroContext`` is the base class which holds the configuration and
Kedro's main functionality.
Create a context object by providing the root of a Kedro project and
the environment configuration subfolders (see ``kedro.config.OmegaConfigLoader``)
Raises:
KedroContextError: If there is a mismatch
between Kedro project version and package version.
Args:
project_path: Project path to define the context for.
config_loader: Kedro's ``OmegaConfigLoader`` for loading the configuration files.
env: Optional argument for configuration default environment to be used
for running the pipeline. If not specified, it defaults to "local".
package_name: Package name for the Kedro project the context is
created for.
hook_manager: The ``PluginManager`` to activate hooks, supplied by the session.
extra_params: Optional dictionary containing extra project parameters.
If specified, will update (and therefore take precedence over)
the parameters retrieved from the project configuration.
"""

_package_name: str
project_path: Path = field(converter=_expand_full_path)
config_loader: AbstractConfigLoader
_hook_manager: PluginManager
env: str | None = None
_extra_params: dict[str, Any] | None = field(default=None, converter=deepcopy)
project_path: Path = field(init=True, converter=_expand_full_path)
config_loader: AbstractConfigLoader = field(init=True)
env: str | None = field(init=True)
_package_name: str = field(init=True)
_hook_manager: PluginManager = field(init=True)
_extra_params: dict[str, Any] | None = field(
init=True, default=None, converter=deepcopy
)

@property
def catalog(self) -> DataCatalog:
Expand Down
13 changes: 6 additions & 7 deletions tests/framework/context/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import pytest
import toml
import yaml
from attrs.exceptions import FrozenInstanceError
from pandas.testing import assert_frame_equal

from kedro import __version__ as kedro_version
Expand Down Expand Up @@ -201,11 +200,11 @@ def dummy_context(tmp_path, prepare_project_dir, env, extra_params):
context = session.load_context()
config_loader = context.config_loader
context = KedroContext(
MOCK_PACKAGE_NAME,
str(tmp_path),
project_path=str(tmp_path),
config_loader=config_loader,
hook_manager=_create_hook_manager(),
env=env,
package_name=MOCK_PACKAGE_NAME,
hook_manager=_create_hook_manager(),
extra_params=extra_params,
)

Expand All @@ -218,9 +217,9 @@ def test_attributes(self, tmp_path, dummy_context):
assert isinstance(dummy_context.project_path, Path)
assert dummy_context.project_path == tmp_path.resolve()

def test_immutable_instance(self, dummy_context):
with pytest.raises(FrozenInstanceError):
dummy_context.catalog = 1
def test_set_new_attribute(self, dummy_context):
dummy_context.mlflow = 1
assert dummy_context.mlflow == 1

def test_get_catalog_always_using_absolute_path(self, dummy_context):
config_loader = dummy_context.config_loader
Expand Down
44 changes: 41 additions & 3 deletions tests/framework/session/test_session.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import logging
import re
import subprocess
import sys
import textwrap
from collections.abc import Mapping
from pathlib import Path
from typing import Any, Type
from unittest.mock import create_autospec

import pytest
import toml
Expand All @@ -13,7 +16,6 @@
from kedro import __version__ as kedro_version
from kedro.config import AbstractConfigLoader, OmegaConfigLoader
from kedro.framework.cli.utils import _split_params
from kedro.framework.context import KedroContext
from kedro.framework.project import (
LOGGING,
ValidationError,
Expand All @@ -23,7 +25,7 @@
_ProjectSettings,
)
from kedro.framework.session import KedroSession
from kedro.framework.session.session import KedroSessionError
from kedro.framework.session.session import KedroContext, KedroSessionError
from kedro.framework.session.shelvestore import ShelveStore
from kedro.framework.session.store import BaseSessionStore

Expand All @@ -43,6 +45,36 @@ class BadConfigLoader:
"""


ATTRS_ATTRIBUTE = "__attrs_attrs__"

NEW_TYPING = sys.version_info[:3] >= (3, 7, 0) # PEP 560


def create_attrs_autospec(spec: Type, spec_set: bool = True) -> Any:
"""Creates a mock of an attr class (creates mocks recursively on all attributes).
https://github.com/python-attrs/attrs/issues/462#issuecomment-1134656377
:param spec: the spec to mock
:param spec_set: if True, AttributeError will be raised if an attribute that is not in the spec is set.
"""

if not hasattr(spec, ATTRS_ATTRIBUTE):
raise TypeError(f"{spec!r} is not an attrs class")
mock = create_autospec(spec, spec_set=spec_set)
for attribute in getattr(spec, ATTRS_ATTRIBUTE):
attribute_type = attribute.type
if NEW_TYPING:
# A[T] does not get a copy of __dict__ from A(Generic[T]) anymore, use __origin__ to get it
while hasattr(attribute_type, "__origin__"):
attribute_type = attribute_type.__origin__
if hasattr(attribute_type, ATTRS_ATTRIBUTE):
mock_attribute = create_attrs_autospec(attribute_type, spec_set)
else:
mock_attribute = create_autospec(attribute_type, spec_set=spec_set)
object.__setattr__(mock, attribute.name, mock_attribute)
return mock


@pytest.fixture
def mock_runner(mocker):
mock_runner = mocker.patch(
Expand All @@ -55,7 +87,12 @@ def mock_runner(mocker):

@pytest.fixture
def mock_context_class(mocker):
return mocker.patch("kedro.framework.session.session.KedroContext", autospec=True)
mock_cls = create_attrs_autospec(KedroContext)
return mocker.patch(
"kedro.framework.session.session.KedroContext",
autospec=True,
return_value=mock_cls,
)


def _mock_imported_settings_paths(mocker, mock_settings):
Expand All @@ -75,6 +112,7 @@ def mock_settings(mocker):
@pytest.fixture
def mock_settings_context_class(mocker, mock_context_class):
class MockSettings(_ProjectSettings):
# dynaconf automatically deleted some attribute when the class is MagicMock
_CONTEXT_CLASS = Validator(
"CONTEXT_CLASS", default=lambda *_: mock_context_class
)
Expand Down

0 comments on commit f285f2c

Please sign in to comment.