Feature: show matplotlib dataset pngs #887

Merged (21 commits) on Jun 8, 2022
5 changes: 4 additions & 1 deletion RELEASE.md
@@ -8,12 +8,15 @@ Please follow the established format:

## Major features and improvements

- Allow the display of Matplotlib images in the metadata panel and modal. (#887)

## Bug fixes and other changes

- Added warning message when filtered pipeline is empty. (#864)
- Improve telemetry to track flowchart events. (#865)
- Disabled uvicorn's logger so that log messages are no longer duplicated. (#870)
- Enhance _Apply and close_ behavior of modals. (#875)
- Fix namespace collison when two different registered pipelines have a modular pipeline with the same name. (#871)
- Fix namespace collision when two registered pipelines have a modular pipeline with the same name. (#871)

# Release 4.6.0

4 changes: 4 additions & 0 deletions demo-project/conf/base/catalog_08_reporting.yml
@@ -20,3 +20,7 @@ reporting.price_histogram:
reporting.cancellation_policy_grid:
type: demo_project.extras.datasets.image_dataset.ImageDataSet
filepath: ${base_location}/08_reporting/cancellation_policy_grid.png

reporting.matplotlib_image:
type: matplotlib.MatplotlibWriter
filepath: ${base_location}/08_reporting/matplot_lib_single_plot.png

2 changes: 1 addition & 1 deletion demo-project/data/08_reporting/price_histogram.json

Large diffs are not rendered by default.

@@ -0,0 +1,6 @@
{
"model_type": "LinearRegression",
"fit_intercept": true,
"copy_X": true,
"positive": false
}
@@ -0,0 +1,3 @@
{
"r2_score": 0.40296896595214116
}
@@ -0,0 +1,15 @@
{
"model_type": "RandomForestRegressor",
"n_estimators": 100,
"criterion": "squared_error",
"min_samples_split": 2,
"min_samples_leaf": 1,
"min_weight_fraction_leaf": 0,
"max_features": "auto",
"min_impurity_decrease": 0,
"bootstrap": true,
"oob_score": false,
"verbose": 0,
"warm_start": false,
"ccp_alpha": 0
}
@@ -0,0 +1,3 @@
{
"r2_score": 0.42034590382281145
}
Binary file modified demo-project/data/session_store.db
Binary file not shown.
6 changes: 6 additions & 0 deletions demo-project/src/demo_project/pipelines/reporting/nodes.py
@@ -2,6 +2,7 @@
This is a boilerplate pipeline 'reporting'
generated using Kedro 0.18.1
"""
import matplotlib.pyplot as plt
import pandas as pd
import PIL
import plotly.express as px
@@ -87,3 +88,8 @@ def make_price_analysis_image(model_input_table: pd.DataFrame) -> PIL.Image:

pil_table = DrawTable(analysis_df)
return pil_table.image


def create_matplotlib_chart(companies: pd.DataFrame):
plt.plot([1, 2, 3], [4, 5, 6])
return plt
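
For context on how this node's output reaches disk: the `reporting.matplotlib_image` catalog entry added above maps the node output to a `matplotlib.MatplotlibWriter`, which saves the figure (or pyplot module) the node returns to the configured filepath. A minimal standalone sketch of the same pattern, outside a pipeline run; the output path here is illustrative, not part of this PR:

```python
# Sketch only: MatplotlibWriter saving a figure directly, the same mechanism
# Kedro applies to the "reporting.matplotlib_image" dataset above.
import matplotlib.pyplot as plt
from kedro.extras.datasets.matplotlib import MatplotlibWriter

fig, ax = plt.subplots()
ax.plot([1, 2, 3], [4, 5, 6])

# Illustrative path; in the demo project the catalog supplies the filepath.
writer = MatplotlibWriter(filepath="data/08_reporting/example_plot.png")
writer.save(fig)
```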
@@ -6,6 +6,7 @@
from kedro.pipeline import Pipeline, node, pipeline

from demo_project.pipelines.reporting.nodes import (
create_matplotlib_chart,
make_cancel_policy_bar_chart,
make_price_analysis_image,
make_price_histogram,
@@ -31,6 +32,11 @@ def create_pipeline(**kwargs) -> Pipeline:
inputs="prm_shuttle_company_reviews",
outputs="cancellation_policy_grid",
),
node(
func=create_matplotlib_chart,
inputs="prm_shuttle_company_reviews",
outputs="matplotlib_image",
),
],
inputs=["prm_shuttle_company_reviews"],
namespace="reporting",
@@ -1,3 +1,2 @@
kedro[plotly.PlotlyDataSet, plotly.JSONDataSet]==0.18.1
matplotlib==3.5.0
kedro[plotly.PlotlyDataSet, plotly.JSONDataSet, matplotlib.MatplotlibWriter]==0.18.1
pillow==9.0.1
3 changes: 1 addition & 2 deletions demo-project/src/docker_requirements.txt
@@ -1,4 +1,3 @@
kedro[pandas.CSVDataSet,pandas.ExcelDataSet, pandas.ParquetDataSet, plotly.PlotlyDataSet]==0.18.1
kedro[pandas.CSVDataSet,pandas.ExcelDataSet, pandas.ParquetDataSet, plotly.PlotlyDataSet, matplotlib.MatplotlibWriter]==0.18.1
scikit-learn~=1.0
pillow~=9.0
matplotlib==3.5.0
1 change: 1 addition & 0 deletions package/kedro_viz/api/responses.py
@@ -111,6 +111,7 @@ class DataNodeMetadataAPIResponse(BaseAPIResponse):
filepath: str
type: str
plot: Optional[Dict]
image: Optional[str]
tracking_data: Optional[Dict]
run_command: Optional[str]

2 changes: 1 addition & 1 deletion package/kedro_viz/data_access/managers.py
@@ -378,7 +378,7 @@ def create_modular_pipelines_tree_for_registered_pipeline(
# Add the modular pipeline node to the global list of nodes if necessary
# and update the list of pipelines it belongs to.
# N.B. Ideally we will have different modular pipeline nodes for
# different registered pipelinesm, but that requires a bit of a bigger refactor
# different registered pipelines, but that requires a bit of a bigger refactor
# so we will just use the same node for now.
self.nodes.add_node(modular_pipeline_node)
self.nodes.get_node_by_id(modular_pipeline_node.id).pipelines = {
30 changes: 28 additions & 2 deletions package/kedro_viz/models/graph.py
@@ -1,6 +1,7 @@
"""`kedro_viz.models.graph` defines data models to represent Kedro entities in a viz graph."""
# pylint: disable=protected-access
import abc
import base64
import hashlib
import inspect
import json
@@ -481,6 +482,13 @@ def is_plot_node(self):
== "kedro.extras.datasets.plotly.json_dataset.JSONDataSet"
)

def is_image_node(self):
"""Check if the current node is a matplotlib image node."""
return (
self.dataset_type
== "kedro.extras.datasets.matplotlib.matplotlib_writer.MatplotlibWriter"
)

def is_metric_node(self):
"""Check if the current node is a metrics node."""
return (
@@ -559,6 +567,10 @@ class DataNodeMetadata(GraphNodeMetadata):
# currently only applicable for PlotlyDataSet
plot: Optional[Dict] = field(init=False)

# the optional image data if the underlying dataset has an image.
# currently only applicable for matplotlib.MatplotlibWriter
image: Optional[str] = field(init=False)

tracking_data: Optional[Dict] = field(init=False)

# command to run the pipeline to this data node
@@ -578,18 +590,32 @@ def __post_init__(self, data_node: DataNode):
from kedro.extras.datasets.plotly.plotly_dataset import PlotlyDataSet

dataset = cast(Union[PlotlyDataSet, PlotlyJSONDataSet], dataset)
if not dataset._exists():
if not dataset.exists():
return

load_path = get_filepath_str(dataset._get_load_path(), dataset._protocol)
with dataset._fs.open(load_path, **dataset._fs_open_args_load) as fs_file:
self.plot = json.load(fs_file)

if data_node.is_image_node():
from kedro.extras.datasets.matplotlib.matplotlib_writer import (
MatplotlibWriter,
)

dataset = cast(MatplotlibWriter, dataset)
if not dataset.exists():
return

load_path = get_filepath_str(dataset._get_load_path(), dataset._protocol)
with open(load_path, "rb") as img_file:
base64_bytes = base64.b64encode(img_file.read())
self.image = base64_bytes.decode("utf-8")

if data_node.is_tracking_node():
from kedro.extras.datasets.tracking.json_dataset import JSONDataSet
from kedro.extras.datasets.tracking.metrics_dataset import MetricsDataSet

if not dataset._exists() or self.filepath is None:
if not dataset.exists() or self.filepath is None:
return

dataset = cast(Union[JSONDataSet, MetricsDataSet], dataset)
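As an aside on the encoding step above: the writer's PNG is read as raw bytes and shipped to the frontend as a base64 string on the metadata response. A standalone sketch of that step (the file name in the comment is only illustrative):

```python
# Sketch: base64-encode a PNG for a JSON API response, as the metadata
# endpoint does; a client can then embed the string in a data URI.
import base64


def encode_image(load_path: str) -> str:
    with open(load_path, "rb") as img_file:
        return base64.b64encode(img_file.read()).decode("utf-8")


# e.g. an <img> src on the frontend could be:
# f"data:image/png;base64,{encode_image('matplot_lib_single_plot.png')}"
```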
58 changes: 47 additions & 11 deletions package/tests/test_models/test_graph/test_graph_nodes.py
@@ -1,12 +1,13 @@
# pylint: disable=too-many-public-methods
import base64
import datetime
import json
import shutil
import time
from functools import partial
from pathlib import Path
from textwrap import dedent
from unittest.mock import MagicMock, call, patch
from unittest.mock import MagicMock, call, mock_open, patch

import pandas as pd
import pytest
@@ -33,7 +34,7 @@


def import_mock(name, *args):
if name.startswith("plotly"):
if name.startswith("matplotlib"):
return MagicMock()
return orig_import(name, *args)

@@ -384,9 +385,8 @@ def test_partitioned_data_node_metadata(self):
data_node_metadata = DataNodeMetadata(data_node=data_node)
assert data_node_metadata.filepath == "partitioned/"

@patch("builtins.__import__", side_effect=import_mock)
Review discussion on the removed @patch("builtins.__import__", side_effect=import_mock) decorator:

Contributor: Why was this removed? It seems to me that we should be able to run the tests without plotly installed (related to the point of removing plotly from requirements.txt).

Member (Author): I'm not sure. @rashidakanchwala? Should I add it back in?

Contributor: So I removed it, because it wasn't actually being used. :S

Contributor: Actually this is used even though patched_import is not explicitly mentioned in the test. When you @patch, it decorates the test function, and doing so also passes an extra argument to the test function. But it turns out it's not needed for another reason: plotly is a requirement of kedro-viz.
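To make that point concrete, here is a minimal sketch (not part of this PR, using generic targets) of how stacked patch decorators inject mock arguments even when the test body never references them:

```python
# Sketch: unittest.mock.patch as a decorator wraps the test function and
# passes the mock in as an extra argument, so an "unused" parameter is still
# doing work for the duration of the test. Collected and run by pytest.
from unittest.mock import patch


@patch("builtins.__import__")  # outermost decorator -> last extra argument
@patch("json.load")            # innermost decorator -> first extra argument
def test_example(patched_json_load, patched_import):
    patched_json_load.return_value = {"data": []}
    # Both json.load and __import__ are mocked while this test runs,
    # whether or not patched_import is referenced in the body.
    assert patched_import is not None
```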

@patch("json.load")
def test_plotly_data_node_metadata(self, patched_json_load, patched_import):
def test_plotly_data_node_metadata(self, patched_json_load):
mock_plot_data = {
"data": [
{
@@ -399,16 +399,17 @@ def test_plotly_data_node_metadata(self, patched_json_load, patched_import):
patched_json_load.return_value = mock_plot_data
plotly_data_node = MagicMock()
plotly_data_node.is_plot_node.return_value = True
plotly_data_node.is_metric_node.return_value = False
plotly_data_node.is_image_node.return_value = False
plotly_data_node.is_tracking_node.return_value = False
plotly_node_metadata = DataNodeMetadata(data_node=plotly_data_node)
assert plotly_node_metadata.plot == mock_plot_data

@patch("builtins.__import__", side_effect=import_mock)
def test_plotly_data_node_dataset_not_exist(self, patched_import):
def test_plotly_data_node_dataset_not_exist(self):
plotly_data_node = MagicMock()
plotly_data_node.is_plot_node.return_value = True
plotly_data_node.is_metric_node.return_value = False
plotly_data_node.kedro_obj._exists.return_value = False
plotly_data_node.is_image_node.return_value = False
plotly_data_node.is_tracking_node.return_value = False
plotly_data_node.kedro_obj.exists.return_value = False
plotly_node_metadata = DataNodeMetadata(data_node=plotly_data_node)
assert not hasattr(plotly_node_metadata, "plot")

@@ -426,10 +427,40 @@ def test_plotly_json_dataset_node_metadata(self, patched_json_load):
patched_json_load.return_value = mock_plot_data
plotly_json_dataset_node = MagicMock()
plotly_json_dataset_node.is_plot_node.return_value = True
plotly_json_dataset_node.is_metric_node.return_value = False
plotly_json_dataset_node.is_image_node.return_value = False
plotly_json_dataset_node.is_tracking_node.return_value = False
plotly_node_metadata = DataNodeMetadata(data_node=plotly_json_dataset_node)
assert plotly_node_metadata.plot == mock_plot_data

@patch("builtins.__import__", side_effect=import_mock)
@patch(
"builtins.open",
new_callable=mock_open,
read_data=base64.b64decode(
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAA"
"AAC0lEQVQYV2NgYAAAAAMAAWgmWQ0AAAAASUVORK5CYII="
),
)
def test_image_data_node_metadata(self, patched_base64, patched_import):
image_dataset_node = MagicMock()
base64_encoded_string = (
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAA"
"AAC0lEQVQYV2NgYAAAAAMAAWgmWQ0AAAAASUVORK5CYII="
)
image_dataset_node.is_image_node.return_value = True
image_dataset_node.is_plot_node.return_value = False
image_node_metadata = DataNodeMetadata(data_node=image_dataset_node)
assert image_node_metadata.image == base64_encoded_string

@patch("builtins.__import__", side_effect=import_mock)
def test_image_data_node_dataset_not_exist(self, patched_import):
image_dataset_node = MagicMock()
image_dataset_node.is_image_node.return_value = True
image_dataset_node.is_plot_node.return_value = False
image_dataset_node.kedro_obj.exists.return_value = False
image_node_metadata = DataNodeMetadata(data_node=image_dataset_node)
assert not hasattr(image_node_metadata, "image")

@patch("kedro_viz.models.graph.DataNodeMetadata.load_versioned_tracking_data")
@patch("kedro_viz.models.graph.DataNodeMetadata.load_latest_tracking_data")
@patch("kedro_viz.models.graph.DataNodeMetadata.create_metrics_plot")
@@ -470,6 +501,7 @@ def test_metrics_data_node_metadata(
patched_metrics_plot.return_value = mock_plot_data
metrics_data_node = MagicMock()
metrics_data_node.is_plot_node.return_value = False
metrics_data_node.is_image_node.return_value = False
metrics_data_node.is_tracking_node.return_value = True
metrics_data_node.is_metric_node.return_value = True
metrics_node_metadata = DataNodeMetadata(data_node=metrics_data_node)
@@ -490,6 +522,7 @@ def test_json_data_node_metadata(
patched_latest_json.return_value = mock_json_data
json_data_node = MagicMock()
json_data_node.is_plot_node.return_value = False
json_data_node.is_image_node.return_value = False
json_data_node.is_tracking_node.return_value = True
json_data_node.is_metric_node.return_value = False
json_node_metadata = DataNodeMetadata(data_node=json_data_node)
@@ -499,8 +532,9 @@
def test_metrics_data_node_metadata_dataset_not_exist(self):
metrics_data_node = MagicMock()
metrics_data_node.is_plot_node.return_value = False
metrics_data_node.is_image_node.return_value = False
metrics_data_node.is_metric_node.return_value = True
metrics_data_node.kedro_obj._exists.return_value = False
metrics_data_node.kedro_obj.exists.return_value = False
metrics_node_metadata = DataNodeMetadata(data_node=metrics_data_node)
assert not hasattr(metrics_node_metadata, "metrics")
assert not hasattr(metrics_node_metadata, "plot")
@@ -513,6 +547,7 @@ def test_data_node_metadata_latest_tracking_data_not_exist(
patched_latest_tracking_data.return_value = None
tracking_data_node = MagicMock()
tracking_data_node.is_plot_node.return_value = False
tracking_data_node.is_image_node.return_value = False
tracking_data_node.is_metric_node.return_value = True
tracking_data_node_metadata = DataNodeMetadata(data_node=tracking_data_node)
assert not hasattr(tracking_data_node_metadata, "metrics")
@@ -534,6 +569,7 @@ def test_tracking_data_node_metadata_versioned_dataset_not_exist(
patched_data_loader.return_value = {}
tracking_data_node = MagicMock()
tracking_data_node.is_plot_node.return_value = False
tracking_data_node.is_image_node.return_value = False
tracking_data_node.is_metric_node.return_value = True
tracking_data_node_metadata = DataNodeMetadata(data_node=tracking_data_node)
assert tracking_data_node_metadata.tracking_data == mock_metrics_data
4 changes: 2 additions & 2 deletions src/actions/index.js
@@ -50,15 +50,15 @@ export function toggleSettingsModal(visible) {
};
}

export const TOGGLE_PLOT_MODAL = 'TOGGLE_PLOT_MODAL';
export const TOGGLE_METADATA_MODAL = 'TOGGLE_METADATA_MODAL';

/**
* Toggle whether to show the plot modal
* @param {Boolean} visible True if the modal is to be shown
*/
export function togglePlotModal(visible) {
return {
type: TOGGLE_PLOT_MODAL,
type: TOGGLE_METADATA_MODAL,
visible,
};
}
@@ -66,7 +66,7 @@ describe('PrimaryToolbar', () => {
visible: expect.objectContaining({
exportBtn: expect.any(Boolean),
exportModal: expect.any(Boolean),
plotModal: expect.any(Boolean),
metadataModal: expect.any(Boolean),
settingsModal: expect.any(Boolean),
labelBtn: expect.any(Boolean),
layerBtn: expect.any(Boolean),
4 changes: 2 additions & 2 deletions src/components/flowchart-wrapper/flowchart-wrapper.js
@@ -6,7 +6,7 @@ import FlowChart from '../flowchart';
import PipelineWarning from '../pipeline-warning';
import LoadingIcon from '../icons/loading';
import MetaData from '../metadata';
import PlotlyModal from '../plotly-modal';
import MetadataModal from '../metadata-modal';
import Sidebar from '../sidebar';
import './flowchart-wrapper.css';

@@ -24,7 +24,7 @@ export const FlowChartWrapper = ({ loading }) => (
<LoadingIcon className="pipeline-wrapper__loading" visible={loading} />
</div>
<ExportModal />
<PlotlyModal />
<MetadataModal />
</div>
);

2 changes: 1 addition & 1 deletion src/components/global-toolbar/global-toolbar.test.js
@@ -55,7 +55,7 @@ describe('GlobalToolbar', () => {
miniMap: true,
miniMapBtn: true,
modularPipelineFocusMode: null,
plotModal: false,
metadataModal: false,
settingsModal: false,
sidebar: true,
},