diff --git a/RELEASE.md b/RELEASE.md
index 2ccf26b08b..ae62fb6b0f 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -7,6 +7,10 @@ Please follow the established format:
 -->
 # Upcoming Release
 
+## Major features and improvements
+
+- Add Dataset Factory Patterns to Experiment Tracking. (#1824)
+
 ## Bug fixes and other changes
 
 - Add support for `JSONDataset` preview. (#1800)
diff --git a/package/kedro_viz/data_access/managers.py b/package/kedro_viz/data_access/managers.py
index 8835e4fe05..4e4e772e5a 100644
--- a/package/kedro_viz/data_access/managers.py
+++ b/package/kedro_viz/data_access/managers.py
@@ -1,6 +1,6 @@
 """`kedro_viz.data_access.managers` defines data access managers."""
-# pylint: disable=too-many-instance-attributes
+# pylint: disable=too-many-instance-attributes,protected-access
 import logging
 from collections import defaultdict
 from typing import Dict, List, Set, Union
 
@@ -69,16 +69,36 @@ def set_db_session(self, db_session_class: sessionmaker):
         """Set db session on repositories that need it."""
         self.runs.set_db_session(db_session_class)
 
-    def add_catalog(self, catalog: DataCatalog):
+    def resolve_dataset_factory_patterns(
+        self, catalog: DataCatalog, pipelines: Dict[str, KedroPipeline]
+    ):
+        """Resolve dataset factory patterns in data catalog by matching
+        them against the datasets in the pipelines.
+        """
+        for pipeline in pipelines.values():
+            if hasattr(pipeline, "data_sets"):
+                # Support for Kedro 0.18.x
+                datasets = pipeline.data_sets()
+            else:
+                datasets = pipeline.datasets()
+
+            for dataset_name in datasets:
+                try:
+                    catalog._get_dataset(dataset_name, suggest=False)
+                # pylint: disable=broad-except
+                except Exception:  # pragma: no cover
+                    continue
+
+    def add_catalog(self, catalog: DataCatalog, pipelines: Dict[str, KedroPipeline]):
         """Resolve dataset factory patterns, add the catalog to the CatalogRepository
         and relevant tracking datasets to TrackingDatasetRepository.
 
         Args:
             catalog: The DataCatalog instance to add.
+            pipelines: A dictionary which holds project pipelines
         """
-        # TODO: Implement dataset factory pattern discovery for
-        # experiment tracking datasets.
+        self.resolve_dataset_factory_patterns(catalog, pipelines)
         self.catalog.set_catalog(catalog)
diff --git a/package/kedro_viz/server.py b/package/kedro_viz/server.py
index 4c8c5ac04f..d439386515 100644
--- a/package/kedro_viz/server.py
+++ b/package/kedro_viz/server.py
@@ -40,7 +40,7 @@ def populate_data(
         session_class = make_db_session_factory(session_store.location)
         data_access_manager.set_db_session(session_class)
 
-    data_access_manager.add_catalog(catalog)
+    data_access_manager.add_catalog(catalog, pipelines)
 
     # add dataset stats before adding pipelines as the data nodes
     # need stats information and they are created during add_pipelines
diff --git a/package/tests/test_api/test_graphql/test_queries.py b/package/tests/test_api/test_graphql/test_queries.py
index 6367ebb7e5..16cfd36ae4 100644
--- a/package/tests/test_api/test_graphql/test_queries.py
+++ b/package/tests/test_api/test_graphql/test_queries.py
@@ -68,8 +68,11 @@ def test_run_tracking_data_query(
         client,
         example_tracking_catalog,
         data_access_manager_with_runs,
+        example_pipelines,
     ):
-        data_access_manager_with_runs.add_catalog(example_tracking_catalog)
+        data_access_manager_with_runs.add_catalog(
+            example_tracking_catalog, example_pipelines
+        )
         example_run_id = example_run_ids[0]
 
         response = client.post(
@@ -170,9 +173,15 @@ def test_run_tracking_data_query(
         assert response.json() == expected_response
 
     def test_metrics_data(
-        self, client, example_tracking_catalog, data_access_manager_with_runs
+        self,
+        client,
+        example_tracking_catalog,
+        data_access_manager_with_runs,
+        example_pipelines,
     ):
-        data_access_manager_with_runs.add_catalog(example_tracking_catalog)
+        data_access_manager_with_runs.add_catalog(
+            example_tracking_catalog, example_pipelines
+        )
 
         response = client.post(
             "/graphql",
@@ -286,8 +295,11 @@ def test_graphql_run_tracking_data(
         data_access_manager_with_runs,
         show_diff,
         expected_response,
+        example_pipelines,
     ):
-        data_access_manager_with_runs.add_catalog(example_multiple_run_tracking_catalog)
+        data_access_manager_with_runs.add_catalog(
+            example_multiple_run_tracking_catalog, example_pipelines
+        )
 
         response = client.post(
             "/graphql",
@@ -343,9 +355,11 @@ def test_graphql_run_tracking_data_at_least_one_empty_run(
         data_access_manager_with_runs,
         show_diff,
         expected_response,
+        example_pipelines,
     ):
         data_access_manager_with_runs.add_catalog(
-            example_multiple_run_tracking_catalog_at_least_one_empty_run
+            example_multiple_run_tracking_catalog_at_least_one_empty_run,
+            example_pipelines,
         )
 
         response = client.post(
@@ -379,9 +393,10 @@ def test_graphql_run_tracking_data_all_empty_runs(
         data_access_manager_with_runs,
         show_diff,
         expected_response,
+        example_pipelines,
     ):
         data_access_manager_with_runs.add_catalog(
-            example_multiple_run_tracking_catalog_all_empty_runs
+            example_multiple_run_tracking_catalog_all_empty_runs, example_pipelines
         )
 
         response = client.post(
diff --git a/package/tests/test_data_access/test_managers.py b/package/tests/test_data_access/test_managers.py
index ce08fa56e1..c81f9819f8 100644
--- a/package/tests/test_data_access/test_managers.py
+++ b/package/tests/test_data_access/test_managers.py
@@ -9,6 +9,7 @@
 
 from kedro_viz.constants import DEFAULT_REGISTERED_PIPELINE_ID, ROOT_MODULAR_PIPELINE_ID
 from kedro_viz.data_access.managers import DataAccessManager
+from kedro_viz.data_access.repositories.catalog import CatalogRepository
 from kedro_viz.models.flowchart import (
     DataNode,
     GraphEdge,
@@ -24,10 +25,14 @@ def identity(x):
 
 
 class TestAddCatalog:
-    def test_add_catalog(self, data_access_manager: DataAccessManager):
+    def test_add_catalog(
+        self,
+        data_access_manager: DataAccessManager,
+        example_pipelines: Dict[str, Pipeline],
+    ):
         dataset = CSVDataset(filepath="dataset.csv")
         catalog = DataCatalog(datasets={"dataset": dataset})
-        data_access_manager.add_catalog(catalog)
+        data_access_manager.add_catalog(catalog, example_pipelines)
 
         assert data_access_manager.catalog.get_catalog() is catalog
 
@@ -65,7 +70,11 @@ def test_add_node_with_modular_pipeline(
             "uk.data_science.modular_pipeline",
         ]
 
-    def test_add_node_input(self, data_access_manager: DataAccessManager):
+    def test_add_node_input(
+        self,
+        data_access_manager: DataAccessManager,
+        example_pipelines: Dict[str, Pipeline],
+    ):
         dataset = CSVDataset(filepath="dataset.csv")
         dataset_name = "x"
         registered_pipeline_id = "my_pipeline"
@@ -80,7 +89,7 @@ def test_add_node_input(self, data_access_manager: DataAccessManager):
         catalog = DataCatalog(
             datasets={dataset_name: dataset},
         )
-        data_access_manager.add_catalog(catalog)
+        data_access_manager.add_catalog(catalog, example_pipelines)
         data_access_manager.add_dataset(registered_pipeline_id, dataset_name)
         data_node = data_access_manager.add_node_input(
             registered_pipeline_id, dataset_name, task_node
@@ -104,11 +113,15 @@ def test_add_node_input(self, data_access_manager: DataAccessManager):
             }
         }
 
-    def test_add_parameters_as_node_input(self, data_access_manager: DataAccessManager):
+    def test_add_parameters_as_node_input(
+        self,
+        data_access_manager: DataAccessManager,
+        example_pipelines: Dict[str, Pipeline],
+    ):
         parameters = {"train_test_split": 0.1, "num_epochs": 1000}
         catalog = DataCatalog()
         catalog.add_feed_dict({"parameters": parameters})
-        data_access_manager.add_catalog(catalog)
+        data_access_manager.add_catalog(catalog, example_pipelines)
         registered_pipeline_id = "my_pipeline"
         kedro_node = node(identity, inputs="parameters", outputs="output")
         task_node = data_access_manager.add_node(registered_pipeline_id, kedro_node)
@@ -119,11 +132,13 @@ def test_add_parameters_as_node_input(self, data_access_manag
         assert task_node.parameters == parameters
 
     def test_add_single_parameter_as_node_input(
-        self, data_access_manager: DataAccessManager
+        self,
+        data_access_manager: DataAccessManager,
+        example_pipelines: Dict[str, Pipeline],
     ):
         catalog = DataCatalog()
         catalog.add_feed_dict({"params:train_test_split": 0.1})
-        data_access_manager.add_catalog(catalog)
+        data_access_manager.add_catalog(catalog, example_pipelines)
         registered_pipeline_id = "my_pipeline"
         kedro_node = node(identity, inputs="params:train_test_split", outputs="output")
         task_node = data_access_manager.add_node(registered_pipeline_id, kedro_node)
@@ -136,11 +151,12 @@ def test_add_single_parameter_as_node_input(
     def test_parameters_yaml_namespace_not_added_to_modular_pipelines(
         self,
         data_access_manager: DataAccessManager,
+        example_pipelines: Dict[str, Pipeline],
     ):
         parameter_name = "params:uk.data_science.train_test_split.ratio"
         catalog = DataCatalog()
         catalog.add_feed_dict({parameter_name: 0.1})
-        data_access_manager.add_catalog(catalog)
+        data_access_manager.add_catalog(catalog, example_pipelines)
         registered_pipeline_id = "my_pipeline"
         kedro_node = node(
             identity,
@@ -160,7 +176,11 @@ def test_parameters_yaml_namespace_not_added_to_modular_pipelines(
         # make sure parameters YAML namespace not accidentally added to the modular pipeline tree
         assert "uk.data_science.train_test_split" not in modular_pipelines_tree
 
-    def test_add_node_output(self, data_access_manager: DataAccessManager):
+    def test_add_node_output(
+        self,
+        data_access_manager: DataAccessManager,
+        example_pipelines: Dict[str, Pipeline],
+    ):
         dataset = CSVDataset(filepath="dataset.csv")
         registered_pipeline_id = "my_pipeline"
         dataset_name = "x"
@@ -175,7 +195,7 @@ def test_add_node_output(self, data_access_manager: DataAccessManager):
         catalog = DataCatalog(
             datasets={dataset_name: dataset},
         )
-        data_access_manager.add_catalog(catalog)
+        data_access_manager.add_catalog(catalog, example_pipelines)
         data_access_manager.add_dataset(registered_pipeline_id, dataset_name)
         data_node = data_access_manager.add_node_output(
             registered_pipeline_id, dataset_name, task_node
@@ -200,11 +220,15 @@
 
 
 class TestAddDataset:
-    def test_add_dataset(self, data_access_manager: DataAccessManager):
+    def test_add_dataset(
+        self,
+        data_access_manager: DataAccessManager,
+        example_pipelines: Dict[str, Pipeline],
+    ):
         dataset = CSVDataset(filepath="dataset.csv")
         dataset_name = "x"
         catalog = DataCatalog(datasets={dataset_name: dataset})
-        data_access_manager.add_catalog(catalog)
+        data_access_manager.add_catalog(catalog, example_pipelines)
         data_access_manager.add_dataset("my_pipeline", dataset_name)
 
         # dataset should be added as a graph node
@@ -217,10 +241,12 @@ def test_add_dataset(self, data_access_manager: DataAccessManager):
         assert not graph_node.modular_pipelines
 
     def test_add_memory_dataset_when_dataset_not_in_catalog(
-        self, data_access_manager: DataAccessManager
+        self,
+        data_access_manager: DataAccessManager,
+        example_pipelines: Dict[str, Pipeline],
     ):
         catalog = DataCatalog()
-        data_access_manager.add_catalog(catalog)
+        data_access_manager.add_catalog(catalog, example_pipelines)
         data_access_manager.add_dataset("my_pipeline", "memory_dataset")
         # dataset should be added as a graph node
         nodes_list = data_access_manager.nodes.as_list()
@@ -230,14 +256,16 @@ def test_add_memory_dataset_when_dataset_not_in_catalog(
         assert isinstance(graph_node.kedro_obj, MemoryDataset)
 
     def test_add_dataset_with_modular_pipeline(
-        self, data_access_manager: DataAccessManager
+        self,
+        data_access_manager: DataAccessManager,
+        example_pipelines: Dict[str, Pipeline],
     ):
         dataset = CSVDataset(filepath="dataset.csv")
         dataset_name = "uk.data_science.x"
         catalog = DataCatalog(
             datasets={dataset_name: dataset},
         )
-        data_access_manager.add_catalog(catalog)
+        data_access_manager.add_catalog(catalog, example_pipelines)
         data_access_manager.add_dataset("my_pipeline", dataset_name)
         nodes_list = data_access_manager.nodes.as_list()
         graph_node: DataNode = nodes_list[0]
@@ -246,12 +274,16 @@ def test_add_dataset_with_modular_pipeline(
             "uk.data_science",
         ]
 
-    def test_add_all_parameters(self, data_access_manager: DataAccessManager):
+    def test_add_all_parameters(
+        self,
+        data_access_manager: DataAccessManager,
+        example_pipelines: Dict[str, Pipeline],
+    ):
         catalog = DataCatalog()
         catalog.add_feed_dict(
             {"parameters": {"train_test_split": 0.1, "num_epochs": 1000}}
         )
-        data_access_manager.add_catalog(catalog)
+        data_access_manager.add_catalog(catalog, example_pipelines)
         data_access_manager.add_dataset("my_pipeline", "parameters")
 
         nodes_list = data_access_manager.nodes.as_list()
@@ -264,10 +296,14 @@ def test_add_all_parameters(self, data_access_manager: DataAccessManager):
             "num_epochs": 1000,
         }
 
-    def test_add_single_parameter(self, data_access_manager: DataAccessManager):
+    def test_add_single_parameter(
+        self,
+        data_access_manager: DataAccessManager,
+        example_pipelines: Dict[str, Pipeline],
+    ):
         catalog = DataCatalog()
         catalog.add_feed_dict({"params:train_test_split": 0.1})
-        data_access_manager.add_catalog(catalog)
+        data_access_manager.add_catalog(catalog, example_pipelines)
         data_access_manager.add_dataset("my_pipeline", "params:train_test_split")
         nodes_list = data_access_manager.nodes.as_list()
         assert len(nodes_list) == 1
@@ -277,11 +313,13 @@ def test_add_single_parameter(self, data_access_manager: DataAccessManager):
         assert graph_node.parameter_value == 0.1
 
     def test_add_dataset_with_params_prefix(
-        self, data_access_manager: DataAccessManager
+        self,
+        data_access_manager: DataAccessManager,
+        example_pipelines: Dict[str, Pipeline],
     ):
         catalog = DataCatalog()
         catalog.add_feed_dict({"params_train_test_split": 0.1})
-        data_access_manager.add_catalog(catalog)
+        data_access_manager.add_catalog(catalog, example_pipelines)
         data_access_manager.add_dataset("my_pipeline", "params_train_test_split")
         nodes_list = data_access_manager.nodes.as_list()
         assert len(nodes_list) == 1
@@ -297,7 +335,7 @@ def test_add_pipelines(
         example_pipelines: Dict[str, Pipeline],
         example_catalog: DataCatalog,
     ):
-        data_access_manager.add_catalog(example_catalog)
+        data_access_manager.add_catalog(example_catalog, example_pipelines)
         data_access_manager.add_pipelines(example_pipelines)
 
         assert [p.id for p in data_access_manager.registered_pipelines.as_list()] == [
@@ -343,7 +381,9 @@ def test_add_pipelines_with_transcoded_data(
         example_transcoded_pipelines: Dict[str, Pipeline],
         example_transcoded_catalog: DataCatalog,
     ):
-        data_access_manager.add_catalog(example_transcoded_catalog)
+        data_access_manager.add_catalog(
+            example_transcoded_catalog, example_transcoded_pipelines
+        )
         data_access_manager.add_pipelines(example_transcoded_pipelines)
         assert any(
             isinstance(node, TranscodedDataNode)
@@ -365,7 +405,7 @@ def test_different_reigstered_pipelines_having_modular_pipeline_with_same_name(
             ),
         }
 
-        data_access_manager.add_catalog(DataCatalog())
+        data_access_manager.add_catalog(DataCatalog(), registered_pipelines)
         data_access_manager.add_pipelines(registered_pipelines)
         modular_pipeline_tree = (
             data_access_manager.create_modular_pipelines_tree_for_registered_pipeline(
@@ -380,7 +420,7 @@ def test_get_default_selected_pipelines_without_default(
         example_pipelines: Dict[str, Pipeline],
         example_catalog: DataCatalog,
     ):
-        data_access_manager.add_catalog(example_catalog)
+        data_access_manager.add_catalog(example_catalog, example_pipelines)
         del example_pipelines[DEFAULT_REGISTERED_PIPELINE_ID]
         data_access_manager.add_pipelines(example_pipelines)
         assert not data_access_manager.registered_pipelines.get_pipeline_by_id(
@@ -435,7 +475,7 @@ def test_add_pipelines_with_circular_modular_pipelines(
         registered_pipelines = {
             "__default__": internal + external,
         }
-        data_access_manager.add_catalog(DataCatalog())
+        data_access_manager.add_catalog(DataCatalog(), registered_pipelines)
         data_access_manager.add_pipelines(registered_pipelines)
         data_access_manager.create_modular_pipelines_tree_for_registered_pipeline(
             DEFAULT_REGISTERED_PIPELINE_ID
@@ -454,3 +494,25 @@ def test_add_pipelines_with_circular_modular_pipelines(
             digraph.add_edge(edge.source, edge.target)
         with pytest.raises(nx.NetworkXNoCycle):
             nx.find_cycle(digraph)
+
+
+class TestResolveDatasetFactoryPatterns:
+    def test_resolve_dataset_factory_patterns(
+        self,
+        example_catalog,
+        pipeline_with_datasets_mock,
+        pipeline_with_data_sets_mock,
+        data_access_manager: DataAccessManager,
+    ):
+        pipelines = {
+            "pipeline1": pipeline_with_datasets_mock,
+            "pipeline2": pipeline_with_data_sets_mock,
+        }
+        new_catalog = CatalogRepository()
+        new_catalog.set_catalog(example_catalog)
+
+        assert "model_inputs#csv" not in new_catalog.as_dict().keys()
+
+        data_access_manager.resolve_dataset_factory_patterns(example_catalog, pipelines)
+
+        assert "model_inputs#csv" in new_catalog.as_dict().keys()
diff --git a/package/tests/test_integrations/test_azure_deployer.py b/package/tests/test_integrations/test_azure_deployer.py
index 883eb8b76e..4c524811b4 100644
--- a/package/tests/test_integrations/test_azure_deployer.py
+++ b/package/tests/test_integrations/test_azure_deployer.py
@@ -67,7 +67,10 @@ def test_upload_static_files(
         with open(temp_file_path, "w", encoding="utf-8") as temp_file:
             temp_file.write(mock_html_content)
 
-        with mocker.patch("mimetypes.guess_type", return_value=("text/html", None)):
+        mime_patch = mocker.patch(
+            "mimetypes.guess_type", return_value=("text/html", None)
+        )
+        with mime_patch:
             deployer._upload_static_files(tmp_path)
             deployer._fs.write_bytes.assert_called_once_with(
                 path="abfs://$web/test_file.html",
diff --git a/package/tests/test_integrations/test_gcp_deployer.py b/package/tests/test_integrations/test_gcp_deployer.py
index 0620024a61..e8278cedb7 100644
--- a/package/tests/test_integrations/test_gcp_deployer.py
+++ b/package/tests/test_integrations/test_gcp_deployer.py
@@ -48,7 +48,10 @@ def test_upload_static_files(
         with open(temp_file_path, "w", encoding="utf-8") as temp_file:
             temp_file.write(mock_html_content)
 
-        with mocker.patch("mimetypes.guess_type", return_value=("text/html", None)):
+        mime_patch = mocker.patch(
+            "mimetypes.guess_type", return_value=("text/html", None)
+        )
+        with mime_patch:
             deployer._upload_static_files(tmp_path)
             deployer._fs.write_bytes.assert_called_once_with(
                 path=f"gcs://{bucket_name}/test_file.html",
diff --git a/package/tests/test_server.py b/package/tests/test_server.py
index 5bc5559133..391f9511c7 100644
--- a/package/tests/test_server.py
+++ b/package/tests/test_server.py
@@ -70,7 +70,9 @@ def test_run_server_from_project(
         example_pipelines,
     ):
         run_server()
-        patched_data_access_manager.add_catalog.assert_called_once_with(example_catalog)
+        patched_data_access_manager.add_catalog.assert_called_once_with(
+            example_catalog, example_pipelines
+        )
         patched_data_access_manager.add_pipelines.assert_called_once_with(
             example_pipelines
         )
@@ -93,7 +95,9 @@ def test_run_server_from_project_with_sqlite_store(
     ):
         run_server()
         # assert that when running server, data are added correctly to the data access manager
-        patched_data_access_manager.add_catalog.assert_called_once_with(example_catalog)
+        patched_data_access_manager.add_catalog.assert_called_once_with(
+            example_catalog, example_pipelines
+        )
         patched_data_access_manager.add_pipelines.assert_called_once_with(
             example_pipelines
        )
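
Note on the mechanism exercised above: resolve_dataset_factory_patterns relies on Kedro's DataCatalog._get_dataset, which materialises a concrete catalog entry the first time a requested name matches a dataset factory pattern. The following standalone sketch (not part of the patch) illustrates that behaviour; it assumes Kedro >= 0.18.12 with kedro-datasets and pandas installed, and the "{name}#csv" pattern and "model_inputs#csv" name are illustrative, mirroring the fixture used in the new test.

from kedro.io import DataCatalog

# A catalog defined only through a dataset factory pattern: no concrete
# entry exists until a matching dataset name is requested.
catalog = DataCatalog.from_config(
    {
        "{name}#csv": {
            "type": "pandas.CSVDataset",
            "filepath": "data/01_raw/{name}.csv",
        }
    }
)

assert "model_inputs#csv" not in catalog.list()

# Resolving the name against the pattern registers a concrete dataset in the
# catalog, which is what resolve_dataset_factory_patterns() triggers for every
# dataset name found in the project's pipelines.
# pylint: disable=protected-access
catalog._get_dataset("model_inputs#csv", suggest=False)

assert "model_inputs#csv" in catalog.list()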