Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace invalid symbols in the labels for metadata visualization #1670

Merged
merged 9 commits into from
Nov 13, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions sdv/metadata/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,12 @@ def _get_graphviz_extension(filepath):
return None, None


def _replace_special_characters(string):
return string.replace('<', '\<').replace('>', '\>') # noqa: W605


def visualize_graph(nodes, edges, filepath=None):
"""Plot metadata usign graphviz.
"""Plot metadata using graphviz.

Try to generate a plot using graphviz.
If a ``filepath`` is provided save the output into a file.
Expand Down Expand Up @@ -105,10 +109,10 @@ def visualize_graph(nodes, edges, filepath=None):
)

for name, label in nodes.items():
digraph.node(name, label=label)
digraph.node(name, label=_replace_special_characters(label))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it doesn't address the problem.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. It seems like it works if you replace '>' with '\>'. Then in the output graph the label looks correct
Screenshot 2023-11-08 at 9 51 51 PM

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a note, this approach removes the backslash from the column name. So if a column has \< in it, the output will be <.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that's ok since prior to this change if they had '\>' in a label, it would only show '>' anyway.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

About this, can we add an integration test where we fit and sample a synthesizer after the visualize() to ensure we're not breaking the metadata validation or the fit and sample with the change


for parent, child, label in edges:
digraph.edge(parent, child, label=label, arrowhead='oinv')
digraph.edge(parent, child, label=_replace_special_characters(label), arrowhead='oinv')

if filename:
digraph.render(filename=filename, cleanup=True, format=graphviz_extension)
Expand Down
31 changes: 31 additions & 0 deletions tests/integration/metadata/test_visualization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pandas as pd

from sdv.metadata import MultiTableMetadata, SingleTableMetadata


def test_visualize_graph_for_single_table():
"""Test it runs when a column name contains `>`."""
# Setup
data = pd.DataFrame({'>': ['a', 'b', 'c']})
metadata = SingleTableMetadata()
metadata.detect_from_dataframe(data)

# Run
metadata.visualize()


def test_visualize_graph_for_multi_table():
"""Test it runs when a column name contains `>`."""
# Setup
data1 = pd.DataFrame({'>': ['a', 'b', 'c']})
data2 = pd.DataFrame({'>': ['a', 'b', 'c']})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we test both symbols? and maybe the symbols mixed with text

tables = {'1': data1, '2': data2}
metadata = MultiTableMetadata()
metadata.detect_from_dataframes(tables)
metadata.update_column('1', '>', sdtype='id')
metadata.update_column('2', '>', sdtype='id')
metadata.set_primary_key('1', '>')
metadata.add_relationship('1', '2', '>', '>')

# Run
metadata.visualize()