Add Dataset Factory Patterns to Experiment Tracking #1824

Merged: 6 commits, Apr 3, 2024
RELEASE.md (4 changes: 4 additions & 0 deletions)
@@ -7,6 +7,10 @@ Please follow the established format:
 -->
 # Upcoming Release
 
+## Major features and improvements
+
+- Add Dataset Factory Patterns to Experiment Tracking. (#1824)
+
 ## Bug fixes and other changes
 
 - Add support for `JSONDataset` preview. (#1800)
package/kedro_viz/data_access/managers.py (28 changes: 24 additions & 4 deletions)
@@ -1,6 +1,6 @@
 """`kedro_viz.data_access.managers` defines data access managers."""
 
-# pylint: disable=too-many-instance-attributes
+# pylint: disable=too-many-instance-attributes,protected-access
 import logging
 from collections import defaultdict
 from typing import Dict, List, Set, Union
@@ -69,16 +69,36 @@ def set_db_session(self, db_session_class: sessionmaker):
         """Set db session on repositories that need it."""
         self.runs.set_db_session(db_session_class)
 
-    def add_catalog(self, catalog: DataCatalog):
+    def resolve_dataset_factory_patterns(
+        self, catalog: DataCatalog, pipelines: Dict[str, KedroPipeline]
+    ):
+        """Resolve dataset factory patterns in the data catalog by matching
+        them against the datasets in the pipelines.
+        """
+        for pipeline in pipelines.values():
+            if hasattr(pipeline, "data_sets"):
+                # Support for Kedro 0.18.x
+                datasets = pipeline.data_sets()
+            else:
+                datasets = pipeline.datasets()
+
+            for dataset_name in datasets:
+                try:
+                    catalog._get_dataset(dataset_name, suggest=False)
+                # pylint: disable=broad-except
+                except Exception:  # pragma: no cover
+                    continue
+
+    def add_catalog(self, catalog: DataCatalog, pipelines: Dict[str, KedroPipeline]):
         """Resolve dataset factory patterns, add the catalog to the CatalogRepository
         and relevant tracking datasets to TrackingDatasetRepository.
 
         Args:
             catalog: The DataCatalog instance to add.
+            pipelines: A dictionary which holds project pipelines
         """
 
-        # TODO: Implement dataset factory pattern discovery for
-        # experiment tracking datasets.
+        self.resolve_dataset_factory_patterns(catalog, pipelines)
 
         self.catalog.set_catalog(catalog)
 
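For context on the mechanism: `resolve_dataset_factory_patterns` walks every dataset name used by the pipelines and asks the catalog for it. For a name that only matches a dataset factory pattern, `_get_dataset` materialises a concrete dataset and registers it on the catalog as a side effect, which is why the return value is discarded and lookup failures are skipped. A minimal sketch of that behaviour (the `{name}_data` pattern and file paths are illustrative, not from this PR):

```python
# Minimal sketch, assuming a Kedro version with dataset factory support
# (>= 0.18.12). The "{name}_data" pattern and paths are hypothetical.
from kedro.io import DataCatalog

catalog = DataCatalog.from_config(
    {
        # A factory pattern: a template entry, not a concrete dataset.
        "{name}_data": {
            "type": "pandas.CSVDataset",  # "pandas.CSVDataSet" on Kedro 0.18.x
            "filepath": "data/01_raw/{name}.csv",
        }
    }
)

# "companies_data" has no concrete entry yet, but it matches "{name}_data",
# so _get_dataset() resolves the template and registers the dataset.
catalog._get_dataset("companies_data", suggest=False)  # pylint: disable=protected-access

print(catalog.list())  # now includes "companies_data"
```

Once every pattern-matched dataset has been registered this way, the experiment tracking code can look tracking datasets up by name as if they had been declared explicitly in the catalog.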
package/kedro_viz/server.py (2 changes: 1 addition & 1 deletion)
@@ -40,7 +40,7 @@ def populate_data(
     session_class = make_db_session_factory(session_store.location)
     data_access_manager.set_db_session(session_class)
 
-    data_access_manager.add_catalog(catalog)
+    data_access_manager.add_catalog(catalog, pipelines)
 
     # add dataset stats before adding pipelines as the data nodes
     # need stats information and they are created during add_pipelines
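`populate_data` now threads the project's pipelines through to `add_catalog`. A hedged sketch of the new call shape from a caller's perspective (only `DataAccessManager.add_catalog` is from this PR; the surrounding setup is illustrative):

```python
# Illustrative wiring only; kedro-viz's real entry point is
# kedro_viz.server.populate_data. `pipelines` must behave like the
# project's pipeline registry, i.e. a Mapping[str, Pipeline].
from kedro.framework.project import pipelines  # populated inside a Kedro project
from kedro.io import DataCatalog
from kedro_viz.data_access import DataAccessManager

data_access_manager = DataAccessManager()
catalog = DataCatalog()  # in practice, the catalog from the project context

# Both arguments are now required: the pipelines drive factory-pattern
# resolution before the catalog is registered.
data_access_manager.add_catalog(catalog, pipelines)
```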
package/tests/test_api/test_graphql/test_queries.py (27 changes: 21 additions & 6 deletions)
@@ -68,8 +68,11 @@ def test_run_tracking_data_query(
         client,
         example_tracking_catalog,
         data_access_manager_with_runs,
+        example_pipelines,
     ):
-        data_access_manager_with_runs.add_catalog(example_tracking_catalog)
+        data_access_manager_with_runs.add_catalog(
+            example_tracking_catalog, example_pipelines
+        )
         example_run_id = example_run_ids[0]
 
         response = client.post(
@@ -170,9 +173,15 @@ def test_run_tracking_data_query(
         assert response.json() == expected_response
 
     def test_metrics_data(
-        self, client, example_tracking_catalog, data_access_manager_with_runs
+        self,
+        client,
+        example_tracking_catalog,
+        data_access_manager_with_runs,
+        example_pipelines,
     ):
-        data_access_manager_with_runs.add_catalog(example_tracking_catalog)
+        data_access_manager_with_runs.add_catalog(
+            example_tracking_catalog, example_pipelines
+        )
 
         response = client.post(
             "/graphql",
@@ -286,8 +295,11 @@ def test_graphql_run_tracking_data(
         data_access_manager_with_runs,
         show_diff,
         expected_response,
+        example_pipelines,
     ):
-        data_access_manager_with_runs.add_catalog(example_multiple_run_tracking_catalog)
+        data_access_manager_with_runs.add_catalog(
+            example_multiple_run_tracking_catalog, example_pipelines
+        )
 
         response = client.post(
             "/graphql",
@@ -343,9 +355,11 @@ def test_graphql_run_tracking_data_at_least_one_empty_run(
         data_access_manager_with_runs,
         show_diff,
         expected_response,
+        example_pipelines,
     ):
         data_access_manager_with_runs.add_catalog(
-            example_multiple_run_tracking_catalog_at_least_one_empty_run
+            example_multiple_run_tracking_catalog_at_least_one_empty_run,
+            example_pipelines,
         )
 
         response = client.post(
@@ -379,9 +393,10 @@ def test_graphql_run_tracking_data_all_empty_runs(
         data_access_manager_with_runs,
         show_diff,
         expected_response,
+        example_pipelines,
     ):
         data_access_manager_with_runs.add_catalog(
-            example_multiple_run_tracking_catalog_all_empty_runs
+            example_multiple_run_tracking_catalog_all_empty_runs, example_pipelines
         )
 
         response = client.post(
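Each updated test now takes an `example_pipelines` fixture alongside the catalog fixtures. The fixture itself lives in the suite's conftest; a hypothetical sketch of its shape (node and dataset names invented for illustration):

```python
# Hypothetical conftest.py fixture; the real example_pipelines fixture in
# the kedro-viz test suite will differ in detail.
import pytest
from kedro.pipeline import Pipeline, node


def identity(x):
    return x


@pytest.fixture
def example_pipelines():
    # Mirrors a project's pipeline registry: pipeline name -> Pipeline.
    return {
        "__default__": Pipeline(
            [node(identity, inputs="metrics", outputs="model_metrics", name="copy")]
        )
    }
```

Passing a mapping like this through `add_catalog` is enough for `resolve_dataset_factory_patterns` to iterate the pipelines' dataset names during the tests.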