Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Running single table report on multi table data (or vice versa) results in confusing error #522

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions sdmetrics/reports/base_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,21 @@ def _validate_metadata_matches_data(self, real_data, synthetic_data, metadata):
)
raise ValueError(error_message)

def _validate_data_format(self, real_data, synthetic_data):
"""Validate that the real and synthetic data are pd.DataFrame for single table reports."""
is_real_dataframe = isinstance(real_data, pd.DataFrame)
is_synthetic_dataframe = isinstance(synthetic_data, pd.DataFrame)
if is_real_dataframe and is_synthetic_dataframe:
return

error_message = (
f'Single table report {self.__class__.__name__} expects real and synthetic data to be'
' pandas.DataFrame. If your real and synthetic data are dictionaries of tables, '
f'please use the multi-table {self.__class__.__name__} instead.'

)
raise ValueError(error_message)

def _validate(self, real_data, synthetic_data, metadata):
"""Validate the inputs.

Expand All @@ -64,6 +79,7 @@ def _validate(self, real_data, synthetic_data, metadata):
metadata (dict):
The metadata of the table.
"""
self._validate_data_format(real_data, synthetic_data)
self._validate_metadata_matches_data(real_data, synthetic_data, metadata)

def _handle_results(self, verbose):
Expand Down
24 changes: 24 additions & 0 deletions sdmetrics/reports/multi_table/base_multi_table_report.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Single table base property class."""
import pandas as pd

from sdmetrics.reports.base_report import BaseReport
from sdmetrics.visualization import set_plotly_config

Expand All @@ -15,6 +17,28 @@ def __init__(self):
super().__init__()
self.table_names = []

def _validate_data_format(self, real_data, synthetic_data):
"""Validate that the real and synthetic are dictionnaries of tables."""
is_real_dict = isinstance(real_data, dict)
is_synthetic_dict = isinstance(synthetic_data, dict)
if is_real_dict and is_synthetic_dict:
all_real_dataframes = all(
isinstance(table, pd.DataFrame) for table in real_data.values()
)
all_synthetic_dataframes = all(
isinstance(table, pd.DataFrame) for table in synthetic_data.values()
)
if all_real_dataframes and all_synthetic_dataframes:
return

error_message = (
f'Multi table report {self.__class__.__name__} expects real and synthetic data to be'
' dictionaries of pandas.DataFrame. If your real and synthetic data are pd.DataFrame,'
f' please use the single-table {self.__class__.__name__} instead.'
)

raise ValueError(error_message)

def _validate_relationships(self, real_data, synthetic_data, metadata):
"""Validate that the relationships are valid."""
for rel in metadata.get('relationships', []):
Expand Down
20 changes: 20 additions & 0 deletions tests/unit/reports/multi_table/test_base_multi_table_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pandas as pd
import pytest

from sdmetrics.demos import load_demo
from sdmetrics.reports.multi_table.base_multi_table_report import BaseMultiTableReport


Expand All @@ -22,6 +23,25 @@ def test__init__(self):
assert report._properties == {}
assert report.table_names == []

def test__validate_data_format(self):
"""Test the ``_validate_data_format`` method.

This test checks that the method raises an error when the real and synthetic data are not
dictionnaries of pd.DataFrame.
"""
# Setup
base_report = BaseMultiTableReport()
real_data, synthetic_data, _ = load_demo('single_table')

# Run and Assert
expected_message = (
'Multi table report BaseMultiTableReport expects real and synthetic data to be '
'dictionaries of pandas.DataFrame. If your real and synthetic data are '
'pd.DataFrame, please use the single-table BaseMultiTableReport instead.'
)
with pytest.raises(ValueError, match=expected_message):
base_report._validate_data_format(real_data, synthetic_data)

def test__validate_relationships(self):
"""Test the ``_validate_relationships`` method."""
# Setup
Expand Down
20 changes: 20 additions & 0 deletions tests/unit/reports/test_base_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,30 @@
import pandas as pd
import pytest

from sdmetrics.demos import load_demo
from sdmetrics.reports.base_report import BaseReport


class TestBaseReport:
def test__validate_data_format(self):
"""Test the ``_validate_data_format`` method.

This test checks that the method raises an error when the real and synthetic data are not
pandas.DataFrame.
"""
# Setup
base_report = BaseReport()
real_data, synthetic_data, _ = load_demo('multi_table')

# Run and Assert
expected_message = (
'Single table report BaseReport expects real and synthetic data to be '
'pandas.DataFrame. If your real and synthetic data are dictionaries of '
'tables, please use the multi-table BaseReport instead.'
)
with pytest.raises(ValueError, match=expected_message):
base_report._validate_data_format(real_data, synthetic_data)

def test__validate_metadata_matches_data(self):
"""Test the ``_validate_metadata_matches_data`` method.

Expand Down
Loading