Skip to content

Commit

Permalink
Running single table report on multi table data (or vice versa) resul…
Browse files Browse the repository at this point in the history
…ts in confusing error (#522)
  • Loading branch information
R-Palazzo authored Nov 17, 2023
1 parent 2eef3a5 commit bf5ccd2
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 0 deletions.
16 changes: 16 additions & 0 deletions sdmetrics/reports/base_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,21 @@ def _validate_metadata_matches_data(self, real_data, synthetic_data, metadata):
)
raise ValueError(error_message)

def _validate_data_format(self, real_data, synthetic_data):
"""Validate that the real and synthetic data are pd.DataFrame for single table reports."""
is_real_dataframe = isinstance(real_data, pd.DataFrame)
is_synthetic_dataframe = isinstance(synthetic_data, pd.DataFrame)
if is_real_dataframe and is_synthetic_dataframe:
return

error_message = (
f'Single table report {self.__class__.__name__} expects real and synthetic data to be'
' pandas.DataFrame. If your real and synthetic data are dictionaries of tables, '
f'please use the multi-table {self.__class__.__name__} instead.'

)
raise ValueError(error_message)

def _validate(self, real_data, synthetic_data, metadata):
"""Validate the inputs.
Expand All @@ -64,6 +79,7 @@ def _validate(self, real_data, synthetic_data, metadata):
metadata (dict):
The metadata of the table.
"""
self._validate_data_format(real_data, synthetic_data)
self._validate_metadata_matches_data(real_data, synthetic_data, metadata)

def _handle_results(self, verbose):
Expand Down
24 changes: 24 additions & 0 deletions sdmetrics/reports/multi_table/base_multi_table_report.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Single table base property class."""
import pandas as pd

from sdmetrics.reports.base_report import BaseReport
from sdmetrics.visualization import set_plotly_config

Expand All @@ -15,6 +17,28 @@ def __init__(self):
super().__init__()
self.table_names = []

def _validate_data_format(self, real_data, synthetic_data):
"""Validate that the real and synthetic are dictionnaries of tables."""
is_real_dict = isinstance(real_data, dict)
is_synthetic_dict = isinstance(synthetic_data, dict)
if is_real_dict and is_synthetic_dict:
all_real_dataframes = all(
isinstance(table, pd.DataFrame) for table in real_data.values()
)
all_synthetic_dataframes = all(
isinstance(table, pd.DataFrame) for table in synthetic_data.values()
)
if all_real_dataframes and all_synthetic_dataframes:
return

error_message = (
f'Multi table report {self.__class__.__name__} expects real and synthetic data to be'
' dictionaries of pandas.DataFrame. If your real and synthetic data are pd.DataFrame,'
f' please use the single-table {self.__class__.__name__} instead.'
)

raise ValueError(error_message)

def _validate_relationships(self, real_data, synthetic_data, metadata):
"""Validate that the relationships are valid."""
for rel in metadata.get('relationships', []):
Expand Down
20 changes: 20 additions & 0 deletions tests/unit/reports/multi_table/test_base_multi_table_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pandas as pd
import pytest

from sdmetrics.demos import load_demo
from sdmetrics.reports.multi_table.base_multi_table_report import BaseMultiTableReport


Expand All @@ -22,6 +23,25 @@ def test__init__(self):
assert report._properties == {}
assert report.table_names == []

def test__validate_data_format(self):
"""Test the ``_validate_data_format`` method.
This test checks that the method raises an error when the real and synthetic data are not
dictionnaries of pd.DataFrame.
"""
# Setup
base_report = BaseMultiTableReport()
real_data, synthetic_data, _ = load_demo('single_table')

# Run and Assert
expected_message = (
'Multi table report BaseMultiTableReport expects real and synthetic data to be '
'dictionaries of pandas.DataFrame. If your real and synthetic data are '
'pd.DataFrame, please use the single-table BaseMultiTableReport instead.'
)
with pytest.raises(ValueError, match=expected_message):
base_report._validate_data_format(real_data, synthetic_data)

def test__validate_relationships(self):
"""Test the ``_validate_relationships`` method."""
# Setup
Expand Down
20 changes: 20 additions & 0 deletions tests/unit/reports/test_base_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,30 @@
import pandas as pd
import pytest

from sdmetrics.demos import load_demo
from sdmetrics.reports.base_report import BaseReport


class TestBaseReport:
def test__validate_data_format(self):
"""Test the ``_validate_data_format`` method.
This test checks that the method raises an error when the real and synthetic data are not
pandas.DataFrame.
"""
# Setup
base_report = BaseReport()
real_data, synthetic_data, _ = load_demo('multi_table')

# Run and Assert
expected_message = (
'Single table report BaseReport expects real and synthetic data to be '
'pandas.DataFrame. If your real and synthetic data are dictionaries of '
'tables, please use the multi-table BaseReport instead.'
)
with pytest.raises(ValueError, match=expected_message):
base_report._validate_data_format(real_data, synthetic_data)

def test__validate_metadata_matches_data(self):
"""Test the ``_validate_metadata_matches_data`` method.
Expand Down

0 comments on commit bf5ccd2

Please sign in to comment.