-
Notifications
You must be signed in to change notification settings - Fork 324
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add validate to BaseMultiTableSynthesizer
#1092
Changes from 5 commits
57cdb06
fcec8da
53424b2
553b308
172a25f
bd0f932
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,10 @@ | |
from collections import defaultdict | ||
from copy import deepcopy | ||
|
||
import pandas as pd | ||
|
||
from sdv.single_table.copulas import GaussianCopulaSynthesizer | ||
from sdv.single_table.errors import InvalidDataError | ||
|
||
|
||
class BaseMultiTableSynthesizer: | ||
|
@@ -82,3 +85,80 @@ def update_table_parameters(self, table_name, table_parameters): | |
def get_metadata(self): | ||
"""Return the ``MultiTableMetadata`` for this synthesizer.""" | ||
return self.metadata | ||
|
||
def _validate_foreign_keys(self, data): | ||
error_msg = None | ||
errors = [] | ||
for relation in self.metadata._relationships: | ||
child_table = data.get(relation['child_table_name']) | ||
parent_table = data.get(relation['parent_table_name']) | ||
if isinstance(child_table, pd.DataFrame) and isinstance(parent_table, pd.DataFrame): | ||
child_column = child_table[relation['child_foreign_key']] | ||
parent_column = parent_table[relation['parent_primary_key']] | ||
missing_values = child_column[~child_column.isin(parent_column)].unique() | ||
if any(missing_values): | ||
message = ', '.join(missing_values[:5].astype(str)) | ||
if len(missing_values) > 5: | ||
message = f'({message}, + more).' | ||
else: | ||
message = f'({message}).' | ||
|
||
errors.append( | ||
f"Error: foreign key column '{relation['child_foreign_key']}' contains " | ||
f'unknown references: {message} All the values in this column must ' | ||
amontanez24 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
'reference a primary key.' | ||
) | ||
if errors: | ||
error_msg = 'Relationships:\n' | ||
error_msg += '\n'.join(errors) | ||
|
||
return error_msg | ||
|
||
def validate(self, data): | ||
"""Validate data. | ||
|
||
Args: | ||
data (dict): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd call it tables. Data makes me think it's just a dataframe There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm fine with tables, I just followed the issue specification, @amontanez24 green light on the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @npatki I think this was set to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see the same issue with To me, |
||
A dictionary with key as table name and ``pandas.DataFrame`` as value to validate. | ||
|
||
Raises: | ||
ValueError: | ||
Raised when data is not of type pd.DataFrame. | ||
InvalidDataError: | ||
Raised if: | ||
* foreign key does not belong to a primay key | ||
* data columns don't match metadata | ||
* keys have missing values | ||
* primary or alternate keys are not unique | ||
* context columns vary for a sequence key | ||
* values of a column don't satisfy their sdtype | ||
""" | ||
errors = [] | ||
missing_tables = set(self.metadata._tables) - set(data) | ||
if missing_tables: | ||
errors.append(f'The provided data is missing the tables {missing_tables}.') | ||
|
||
for table_name, table_data in data.items(): | ||
try: | ||
self._table_synthesizers[table_name].validate(table_data) | ||
|
||
except (InvalidDataError, ValueError) as error: | ||
if isinstance(error, InvalidDataError): | ||
error_msg = f"Table: '{table_name}'" | ||
for _error in error.errors: | ||
error_msg += f'\nError: {_error}' | ||
|
||
else: | ||
error_msg = str(error) | ||
|
||
errors.append(error_msg) | ||
|
||
except KeyError: | ||
continue | ||
|
||
foreign_key_errors = self._validate_foreign_keys(data) | ||
if foreign_key_errors: | ||
errors.append(foreign_key_errors) | ||
|
||
if errors: | ||
raise InvalidDataError(errors) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,14 @@ | ||
import re | ||
from collections import defaultdict | ||
from unittest.mock import Mock, call | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import pytest | ||
|
||
from sdv.multi_table.base import BaseMultiTableSynthesizer | ||
from sdv.single_table.copulas import GaussianCopulaSynthesizer | ||
from sdv.single_table.errors import InvalidDataError | ||
from tests.utils import get_multi_table_metadata | ||
|
||
|
||
|
@@ -136,3 +142,203 @@ def test_get_metadata(self): | |
|
||
# Assert | ||
assert metadata == result | ||
|
||
def test__validate_foreign_keys(self): | ||
"""Test that when the data matches as expected there are no errors.""" | ||
# Setup | ||
metadata = get_multi_table_metadata() | ||
data = { | ||
'nesreca': pd.DataFrame({ | ||
'id_nesreca': np.arange(10), | ||
'upravna_enota': np.arange(10), | ||
}), | ||
'oseba': pd.DataFrame({ | ||
'upravna_enota': np.arange(10), | ||
'id_nesreca': np.arange(10), | ||
}), | ||
'upravna_enota': pd.DataFrame({ | ||
'id_upravna_enota': np.arange(10), | ||
}), | ||
} | ||
instance = BaseMultiTableSynthesizer(metadata) | ||
|
||
# Run | ||
result = instance._validate_foreign_keys(data) | ||
|
||
# Assert | ||
assert result is None | ||
|
||
def test__validate_foreign_keys_missing_keys(self): | ||
"""Test that errors are being returned. | ||
|
||
When the values of the foreign keys are not within the values of the parent | ||
primary key, a list of errors must be returned indicating the values that are missing. | ||
""" | ||
# Setup | ||
metadata = get_multi_table_metadata() | ||
data = { | ||
'nesreca': pd.DataFrame({ | ||
'id_nesreca': np.arange(0, 20, 2), | ||
'upravna_enota': np.arange(10), | ||
}), | ||
'oseba': pd.DataFrame({ | ||
'upravna_enota': np.arange(10, 20), | ||
'id_nesreca': np.arange(10), | ||
}), | ||
'upravna_enota': pd.DataFrame({ | ||
'id_upravna_enota': np.arange(10), | ||
}), | ||
} | ||
instance = BaseMultiTableSynthesizer(metadata) | ||
|
||
# Run | ||
result = instance._validate_foreign_keys(data) | ||
|
||
# Assert | ||
missing_upravna_enota = ( | ||
'Relationships:\n' | ||
"Error: foreign key column 'upravna_enota' contains unknown references: " | ||
'(10, 11, 12, 13, 14, + more). ' | ||
'All the values in this column must reference a primary key.\n' | ||
"Error: foreign key column 'id_nesreca' contains unknown references: (1, 3, 5, 7, 9)." | ||
' All the values in this column must reference a primary key.' | ||
) | ||
assert result == missing_upravna_enota | ||
|
||
def test_validate(self): | ||
"""Test that no error is being raised when the data is valid.""" | ||
# Setup | ||
metadata = get_multi_table_metadata() | ||
data = { | ||
'nesreca': pd.DataFrame({ | ||
'id_nesreca': np.arange(10), | ||
'upravna_enota': np.arange(10), | ||
}), | ||
'oseba': pd.DataFrame({ | ||
'upravna_enota': np.arange(10), | ||
'id_nesreca': np.arange(10), | ||
}), | ||
'upravna_enota': pd.DataFrame({ | ||
'id_upravna_enota': np.arange(10), | ||
}), | ||
} | ||
|
||
instance = BaseMultiTableSynthesizer(metadata) | ||
|
||
# Run and Assert | ||
instance.validate(data) | ||
|
||
def test_validate_missing_table(self): | ||
"""Test that an error is being raised when there is a missing table in the dictionary.""" | ||
# Setup | ||
metadata = get_multi_table_metadata() | ||
data = { | ||
'nesrecas': pd.DataFrame({ | ||
'id_nesreca': np.arange(10), | ||
'upravna_enota': np.arange(10), | ||
}), | ||
'oseba': pd.DataFrame({ | ||
'upravna_enota': np.arange(10), | ||
'id_nesreca': np.arange(10), | ||
}), | ||
'upravna_enota': pd.DataFrame({ | ||
'id_upravna_enota': np.arange(10), | ||
}), | ||
} | ||
|
||
instance = BaseMultiTableSynthesizer(metadata) | ||
|
||
# Run and Assert | ||
error_msg = "The provided data is missing the tables {'nesreca'}." | ||
with pytest.raises(InvalidDataError, match=error_msg): | ||
instance.validate(data) | ||
|
||
def test_validate_data_is_not_dataframe(self): | ||
"""Test that an error is being raised when the data is not a dataframe.""" | ||
# Setup | ||
metadata = get_multi_table_metadata() | ||
data = { | ||
'nesreca': pd.Series({ | ||
'id_nesreca': np.arange(10), | ||
'upravna_enota': np.arange(10), | ||
}), | ||
'oseba': pd.DataFrame({ | ||
'upravna_enota': np.arange(10), | ||
'id_nesreca': np.arange(10), | ||
}), | ||
'upravna_enota': pd.DataFrame({ | ||
'id_upravna_enota': np.arange(10), | ||
}), | ||
} | ||
|
||
instance = BaseMultiTableSynthesizer(metadata) | ||
|
||
# Run and Assert | ||
error_msg = "Data must be a DataFrame, not a <class 'pandas.core.series.Series'>." | ||
with pytest.raises(InvalidDataError, match=error_msg): | ||
instance.validate(data) | ||
|
||
def test_validate_data_does_not_match(self): | ||
"""Test that an error is being raised when the data does not match the metadata.""" | ||
# Setup | ||
metadata = get_multi_table_metadata() | ||
data = { | ||
'nesreca': pd.DataFrame({ | ||
'id_nesreca': np.arange(10).astype(str), | ||
'upravna_enota': np.arange(10), | ||
}), | ||
'oseba': pd.DataFrame({ | ||
'upravna_enota': np.arange(10).astype(str), | ||
'id_nesreca': np.arange(10), | ||
}), | ||
'upravna_enota': pd.DataFrame({ | ||
'id_upravna_enota': np.arange(10).astype(str), | ||
}), | ||
} | ||
|
||
instance = BaseMultiTableSynthesizer(metadata) | ||
|
||
# Run and Assert | ||
error_msg = re.escape( | ||
'The provided data does not match the metadata:\n' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This error message is a little weird. Maybe the colons at the end are supposed to be dots? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This message comes from the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it's helpful to have the label for each individual table since that helps the user know where to look. Maybe we can get rid of the top line. @npatki do you have thoughts on this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this for validating data against the metadata? The idea was to collect and print all the errors at once, which is why there was probably a colon? >>> synthesizer.validate(data)
InvalidDataError: The provided data does not match the metadata
Error: Invalid values found for numerical column 'age': ('a', 'b', 'c', +more)
Error: Key column 'user_id' contains missing values
Error: Foreign key column 'purchaser_id' contains unknown references: ('Unknown', 'USER_999', 'ZZZ', +more). All the values in must reference a primary key. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The top line is for the overall error we throw (afaik, Python only allows you to throw 1 error at a time). The rest of the rows are the message. Breaking up by table makes sense for cleanly summarizing where the errors happen. This could be similar to what we do in >>> synthesizer.validate(data)
InvalidDataError: The provided data does not match the metadata
Table: 'sessions'
Error: Invalid values found for numerical column 'age': ('a', 'b', 'c', +more)
Table: 'users'
Error: Key column 'user_id' contains missing values
...
Relationships:
Error: Foreign key column 'purchaser_id' contains unknown references: ('Unknown', 'USER_999', 'ZZZ', +more). All the values in must reference a primary key. For context, I'm pasting below what we do for InvalidMetadataError: The metadata is not valid
Table: 'users'
Error: Invalid values ("pii") for datetime column "start_date".
Error: A Unique constraint is being applied to column "age". This column is already a key for that table.
Table: 'transactions'
Error: Invalid regex format string "[A-{6}" for text column "transaction_id"
Error: Unknown key value 'ttid'. Keys should be columns that exist in the table.
Error: Invalid increment value (0.5) in a FixedIncrements constraint. Increments must be positive integers.
Relationships:
Error: Relationship between tables ('users', 'transactions') contains an unknown primary key 'userr_id'.
Error: The relationships in the dataset are disjointed. Tables ('sessions') are not connected to any of the other tables. |
||
"Table: 'nesreca'\n" | ||
"Error: Invalid values found for numerical column 'id_nesreca': ['0', '1', '2', " | ||
"'+ 7 more']." | ||
"\n\nTable: 'oseba'\n" | ||
"Error: Invalid values found for numerical column 'upravna_enota': ['0', '1', '2', " | ||
"'+ 7 more']." | ||
"\n\nTable: 'upravna_enota'\n" | ||
"Error: Invalid values found for numerical column 'id_upravna_enota': ['0', '1', '2', " | ||
"'+ 7 more']." | ||
) | ||
with pytest.raises(InvalidDataError, match=error_msg): | ||
instance.validate(data) | ||
|
||
def test_validate_missing_foreign_keys(self): | ||
"""Test that errors are being raised when there are missing foreign keys.""" | ||
# Setup | ||
metadata = get_multi_table_metadata() | ||
data = { | ||
'nesreca': pd.DataFrame({ | ||
'id_nesreca': np.arange(0, 20, 2), | ||
'upravna_enota': np.arange(10), | ||
}), | ||
'oseba': pd.DataFrame({ | ||
'upravna_enota': np.arange(10), | ||
'id_nesreca': np.arange(10), | ||
}), | ||
'upravna_enota': pd.DataFrame({ | ||
'id_upravna_enota': np.arange(10), | ||
}), | ||
} | ||
instance = BaseMultiTableSynthesizer(metadata) | ||
|
||
# Run and Assert | ||
error_msg = re.escape( | ||
'The provided data does not match the metadata:\n' | ||
'Relationships:\n' | ||
"Error: foreign key column 'id_nesreca' contains unknown references: (1, 3, 5, 7, 9). " | ||
'All the values in this column must reference a primary key.' | ||
) | ||
with pytest.raises(InvalidDataError, match=error_msg): | ||
instance.validate(data) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why did we add the check for if they're dataframes?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Before I was raising the error if the check failed on the
SingleTableSynthesizer
, however, since we want to have all the errors raised at once, I have to check this otherwise it will fail as key error or something else.