Skip to content

Commit

Permalink
Add validate to BaseMultiTableSynthesizer (#1092)
Browse files Browse the repository at this point in the history
* Add fk validation

* Curate message

* Finish unit tests

* Address comments

* Fix error messaging

* Address comments
  • Loading branch information
pvk-developer authored and amontanez24 committed Dec 21, 2022
1 parent 2c76417 commit 34005e5
Show file tree
Hide file tree
Showing 2 changed files with 285 additions and 0 deletions.
79 changes: 79 additions & 0 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
from collections import defaultdict
from copy import deepcopy

import pandas as pd

from sdv.single_table.copulas import GaussianCopulaSynthesizer
from sdv.single_table.errors import InvalidDataError


class BaseMultiTableSynthesizer:
Expand Down Expand Up @@ -82,3 +85,79 @@ def update_table_parameters(self, table_name, table_parameters):
def get_metadata(self):
    """Give access to the ``MultiTableMetadata`` stored on this synthesizer.

    Returns:
        The object held in ``self.metadata``.
    """
    return self.metadata

def _validate_foreign_keys(self, data):
    """Check that every foreign key value references an existing parent primary key.

    Relationships whose child or parent table is absent from ``data``, or is not a
    ``pandas.DataFrame``, are skipped.

    Args:
        data (dict):
            A dictionary with key as table name and ``pandas.DataFrame`` as value.

    Returns:
        str or None:
            A message starting with ``'Relationships:'`` followed by one line per
            foreign key column that contains unknown references, or ``None`` when
            no unknown reference is found.
    """
    errors = []
    for relation in self.metadata._relationships:
        child_table = data.get(relation['child_table_name'])
        parent_table = data.get(relation['parent_table_name'])
        both_are_frames = (
            isinstance(child_table, pd.DataFrame) and isinstance(parent_table, pd.DataFrame)
        )
        if not both_are_frames:
            continue

        foreign_key = relation['child_foreign_key']
        child_column = child_table[foreign_key]
        parent_column = parent_table[relation['parent_primary_key']]
        missing_values = child_column[~child_column.isin(parent_column)].unique()

        # Check emptiness by length: ``any(missing_values)`` tests the truthiness of
        # the values themselves and would overlook falsy references such as 0 or ''.
        if len(missing_values) == 0:
            continue

        # Only the first five unknown references are listed in the message.
        message = ', '.join(missing_values[:5].astype(str))
        if len(missing_values) > 5:
            message = f'({message}, + more)'
        else:
            message = f'({message})'

        errors.append(
            f"Error: foreign key column '{foreign_key}' contains "
            f'unknown references: {message}. All the values in this column must '
            'reference a primary key.'
        )

    if not errors:
        return None

    return 'Relationships:\n' + '\n'.join(errors)

def validate(self, data):
    """Validate the data of every table and the relationships between tables.

    All the problems found are collected and raised together in a single
    ``InvalidDataError``. Tables in ``data`` that have no table synthesizer
    are skipped.

    Args:
        data (dict):
            A dictionary with key as table name and ``pandas.DataFrame`` as value to validate.

    Raises:
        InvalidDataError:
            Raised if:
                * a table listed in the metadata is missing from ``data``
                * a table synthesizer rejects its data with a ``ValueError``
                  (the error message is included in the raised error)
                * foreign key does not belong to a primary key
                * data columns don't match metadata
                * keys have missing values
                * primary or alternate keys are not unique
                * context columns vary for a sequence key
                * values of a column don't satisfy their sdtype
    """
    errors = []
    missing_tables = set(self.metadata._tables) - set(data)
    if missing_tables:
        errors.append(f'The provided data is missing the tables {missing_tables}.')

    for table_name, table_data in data.items():
        # Guard only the lookup: a table without a synthesizer is skipped, but a
        # ``KeyError`` raised while validating a known table must not be swallowed,
        # otherwise that table would silently be treated as valid.
        try:
            table_synthesizer = self._table_synthesizers[table_name]
        except KeyError:
            continue

        try:
            table_synthesizer.validate(table_data)

        except InvalidDataError as error:
            table_errors = ''.join(f'\nError: {_error}' for _error in error.errors)
            errors.append(f"Table: '{table_name}'{table_errors}")

        except ValueError as error:
            errors.append(str(error))

    foreign_key_errors = self._validate_foreign_keys(data)
    if foreign_key_errors:
        errors.append(foreign_key_errors)

    if errors:
        raise InvalidDataError(errors)
206 changes: 206 additions & 0 deletions tests/unit/multi_table/test_base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import re
from collections import defaultdict
from unittest.mock import Mock, call

import numpy as np
import pandas as pd
import pytest

from sdv.multi_table.base import BaseMultiTableSynthesizer
from sdv.single_table.copulas import GaussianCopulaSynthesizer
from sdv.single_table.errors import InvalidDataError
from tests.utils import get_multi_table_metadata


Expand Down Expand Up @@ -136,3 +142,203 @@ def test_get_metadata(self):

# Assert
assert metadata == result

def test__validate_foreign_keys(self):
    """Test that no message is returned when all foreign keys match a primary key."""
    # Setup
    nesreca = pd.DataFrame({
        'id_nesreca': np.arange(10),
        'upravna_enota': np.arange(10),
    })
    oseba = pd.DataFrame({
        'upravna_enota': np.arange(10),
        'id_nesreca': np.arange(10),
    })
    upravna_enota = pd.DataFrame({'id_upravna_enota': np.arange(10)})
    tables = {'nesreca': nesreca, 'oseba': oseba, 'upravna_enota': upravna_enota}
    synthesizer = BaseMultiTableSynthesizer(get_multi_table_metadata())

    # Run
    foreign_key_errors = synthesizer._validate_foreign_keys(tables)

    # Assert
    assert foreign_key_errors is None

def test__validate_foreign_keys_missing_keys(self):
    """Test that unknown foreign key references are reported.

    When foreign key values are not found among the values of the parent primary key,
    the returned message must list, per foreign key column, the values that are missing.
    """
    # Setup
    nesreca = pd.DataFrame({
        'id_nesreca': np.arange(0, 20, 2),
        'upravna_enota': np.arange(10),
    })
    oseba = pd.DataFrame({
        'upravna_enota': np.arange(10, 20),
        'id_nesreca': np.arange(10),
    })
    upravna_enota = pd.DataFrame({'id_upravna_enota': np.arange(10)})
    tables = {'nesreca': nesreca, 'oseba': oseba, 'upravna_enota': upravna_enota}
    synthesizer = BaseMultiTableSynthesizer(get_multi_table_metadata())

    # Run
    foreign_key_errors = synthesizer._validate_foreign_keys(tables)

    # Assert
    expected_lines = [
        'Relationships:',
        "Error: foreign key column 'upravna_enota' contains unknown references: "
        '(10, 11, 12, 13, 14, + more). '
        'All the values in this column must reference a primary key.',
        "Error: foreign key column 'id_nesreca' contains unknown references: (1, 3, 5, 7, 9)."
        ' All the values in this column must reference a primary key.',
    ]
    assert foreign_key_errors == '\n'.join(expected_lines)

def test_validate(self):
    """Test that valid data passes validation without raising."""
    # Setup
    nesreca = pd.DataFrame({
        'id_nesreca': np.arange(10),
        'upravna_enota': np.arange(10),
    })
    oseba = pd.DataFrame({
        'upravna_enota': np.arange(10),
        'id_nesreca': np.arange(10),
    })
    upravna_enota = pd.DataFrame({'id_upravna_enota': np.arange(10)})
    tables = {'nesreca': nesreca, 'oseba': oseba, 'upravna_enota': upravna_enota}
    synthesizer = BaseMultiTableSynthesizer(get_multi_table_metadata())

    # Run and Assert
    synthesizer.validate(tables)

def test_validate_missing_table(self):
    """Test that an error is being raised when there is a missing table in the dictionary."""
    # Setup
    metadata = get_multi_table_metadata()
    data = {
        # Deliberately misspelled key so that the 'nesreca' table is missing.
        'nesrecas': pd.DataFrame({
            'id_nesreca': np.arange(10),
            'upravna_enota': np.arange(10),
        }),
        'oseba': pd.DataFrame({
            'upravna_enota': np.arange(10),
            'id_nesreca': np.arange(10),
        }),
        'upravna_enota': pd.DataFrame({
            'id_upravna_enota': np.arange(10),
        }),
    }

    instance = BaseMultiTableSynthesizer(metadata)

    # Run and Assert
    # ``match`` is a regular expression: escape the literal message so that the
    # braces and the dot are matched verbatim instead of as regex syntax.
    error_msg = re.escape("The provided data is missing the tables {'nesreca'}.")
    with pytest.raises(InvalidDataError, match=error_msg):
        instance.validate(data)

def test_validate_data_is_not_dataframe(self):
    """Test that an error is being raised when the data is not a dataframe."""
    # Setup
    metadata = get_multi_table_metadata()
    data = {
        # A ``pandas.Series`` instead of a ``pandas.DataFrame`` for this table.
        'nesreca': pd.Series({
            'id_nesreca': np.arange(10),
            'upravna_enota': np.arange(10),
        }),
        'oseba': pd.DataFrame({
            'upravna_enota': np.arange(10),
            'id_nesreca': np.arange(10),
        }),
        'upravna_enota': pd.DataFrame({
            'id_upravna_enota': np.arange(10),
        }),
    }

    instance = BaseMultiTableSynthesizer(metadata)

    # Run and Assert
    # ``match`` is a regular expression: escape the literal message so that the
    # dots are matched verbatim instead of as "any character".
    error_msg = re.escape(
        "Data must be a DataFrame, not a <class 'pandas.core.series.Series'>."
    )
    with pytest.raises(InvalidDataError, match=error_msg):
        instance.validate(data)

def test_validate_data_does_not_match(self):
    """Test that string values in numerical columns are reported for every table."""
    # Setup
    nesreca = pd.DataFrame({
        'id_nesreca': np.arange(10).astype(str),
        'upravna_enota': np.arange(10),
    })
    oseba = pd.DataFrame({
        'upravna_enota': np.arange(10).astype(str),
        'id_nesreca': np.arange(10),
    })
    upravna_enota = pd.DataFrame({'id_upravna_enota': np.arange(10).astype(str)})
    tables = {'nesreca': nesreca, 'oseba': oseba, 'upravna_enota': upravna_enota}
    synthesizer = BaseMultiTableSynthesizer(get_multi_table_metadata())

    # Run and Assert
    # One section per table, listed in the same order as the tables in ``tables``.
    invalid_columns = [
        ('nesreca', 'id_nesreca'),
        ('oseba', 'upravna_enota'),
        ('upravna_enota', 'id_upravna_enota'),
    ]
    table_sections = [
        f"Table: '{table}'\n"
        f"Error: Invalid values found for numerical column '{column}': ['0', '1', '2', "
        "'+ 7 more']."
        for table, column in invalid_columns
    ]
    expected_message = (
        'The provided data does not match the metadata:\n' + '\n\n'.join(table_sections)
    )
    with pytest.raises(InvalidDataError, match=re.escape(expected_message)):
        synthesizer.validate(tables)

def test_validate_missing_foreign_keys(self):
    """Test that foreign keys without a matching primary key make validation fail."""
    # Setup
    nesreca = pd.DataFrame({
        'id_nesreca': np.arange(0, 20, 2),
        'upravna_enota': np.arange(10),
    })
    oseba = pd.DataFrame({
        'upravna_enota': np.arange(10),
        'id_nesreca': np.arange(10),
    })
    upravna_enota = pd.DataFrame({'id_upravna_enota': np.arange(10)})
    tables = {'nesreca': nesreca, 'oseba': oseba, 'upravna_enota': upravna_enota}
    synthesizer = BaseMultiTableSynthesizer(get_multi_table_metadata())

    # Run and Assert
    expected_lines = [
        'The provided data does not match the metadata:',
        'Relationships:',
        "Error: foreign key column 'id_nesreca' contains unknown references: (1, 3, 5, 7, 9). "
        'All the values in this column must reference a primary key.',
    ]
    expected_message = re.escape('\n'.join(expected_lines))
    with pytest.raises(InvalidDataError, match=expected_message):
        synthesizer.validate(tables)

0 comments on commit 34005e5

Please sign in to comment.