Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add verbosity to drop_unknown_references #1854

Merged
merged 6 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions sdv/utils/poc.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
"""Utility functions."""
import sys

import pandas as pd

from sdv._utils import (
_get_relationship_for_child, _get_rows_to_drop, _validate_foreign_keys_not_null)
from sdv.errors import InvalidDataError, SynthesizerInputError


def drop_unknown_references(metadata, data, drop_missing_values=True):
def drop_unknown_references(metadata, data, drop_missing_values=True, verbose=True):
"""Drop rows with unknown foreign keys.

Args:
Expand All @@ -17,22 +21,38 @@ def drop_unknown_references(metadata, data, drop_missing_values=True):
Boolean describing whether or not to also drop foreign keys with missing values
If True, drop rows with missing values in the foreign keys.
Defaults to True.
verbose (bool):
If True, print information about the rows that are dropped.
Defaults to True.

Returns:
dict:
Dictionary with the dataframes ensuring referential integrity.
"""
success_message = 'Success! All foreign keys have referential integrity.'
table_names = sorted(metadata.tables)
summary_table = pd.DataFrame({
'Table Name': table_names,
'# Rows (Original)': [len(data[table]) for table in table_names],
'# Invalid Rows': [0] * len(table_names),
'# Rows (New)': [len(data[table]) for table in table_names]
})
metadata.validate()
try:
metadata.validate_data(data)
if drop_missing_values:
_validate_foreign_keys_not_null(metadata, data)

if verbose:
sys.stdout.write(
'\n'.join([success_message, '', summary_table.to_string(index=False)])
)

return data
except (InvalidDataError, SynthesizerInputError):
result = data.copy()
table_to_idx_to_drop = _get_rows_to_drop(metadata, result)
for table in metadata.tables:
for table in table_names:
idx_to_drop = table_to_idx_to_drop[table]
result[table] = result[table].drop(idx_to_drop)
if drop_missing_values:
Expand All @@ -47,4 +67,13 @@ def drop_unknown_references(metadata, data, drop_missing_values=True):
'Try providing different data for this table.'
])

if verbose:
summary_table['# Invalid Rows'] = [
len(data[table]) - len(result[table]) for table in table_names
]
summary_table['# Rows (New)'] = [len(result[table]) for table in table_names]
sys.stdout.write('\n'.join([
success_message, '', summary_table.to_string(index=False)
]))

return result
34 changes: 30 additions & 4 deletions tests/integration/utils/test_poc.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def data():
}


def test_drop_unknown_references(metadata, data):
def test_drop_unknown_references(metadata, data, capsys):
"""Test ``drop_unknown_references`` end to end."""
# Run
expected_message = re.escape(
Expand All @@ -75,28 +75,44 @@ def test_drop_unknown_references(metadata, data):

cleaned_data = drop_unknown_references(metadata, data)
metadata.validate_data(cleaned_data)
captured = capsys.readouterr()

# Assert
pd.testing.assert_frame_equal(cleaned_data['parent'], data['parent'])
pd.testing.assert_frame_equal(cleaned_data['child'], data['child'].iloc[:4])
assert len(cleaned_data['child']) == 4
expected_output = (
'Success! All foreign keys have referential integrity.\n\n'
'Table Name # Rows (Original) # Invalid Rows # Rows (New)\n'
' child 5 1 4\n'
' parent 5 0 5'
)
assert captured.out.strip() == expected_output


def test_drop_unknown_references_valid_data(metadata, data):
def test_drop_unknown_references_valid_data(metadata, data, capsys):
"""Test ``drop_unknown_references`` when data has referential integrity."""
# Setup
data = deepcopy(data)
data['child'].loc[4, 'parent_id'] = 2

# Run
result = drop_unknown_references(metadata, data)
captured = capsys.readouterr()

# Assert
pd.testing.assert_frame_equal(result['parent'], data['parent'])
pd.testing.assert_frame_equal(result['child'], data['child'])
expected_message = (
'Success! All foreign keys have referential integrity.\n\n'
'Table Name # Rows (Original) # Invalid Rows # Rows (New)\n'
' child 5 0 5\n'
' parent 5 0 5'
)
assert captured.out.strip() == expected_message
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would rather we not have a special case message for this case. In the spec, we said we should print out for every table all the time.

Even if nothing was dropped, I think it's useful reassurance to the user that we checked all tables. Plus, it simplifies our logic.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright, done in 7a80b0a



def test_drop_unknown_references_drop_missing_values(metadata, data):
def test_drop_unknown_references_drop_missing_values(metadata, data, capsys):
"""Test ``drop_unknown_references`` when there is missing values in the foreign keys."""
# Setup
data = deepcopy(data)
Expand All @@ -105,11 +121,19 @@ def test_drop_unknown_references_drop_missing_values(metadata, data):
# Run
cleaned_data = drop_unknown_references(metadata, data)
metadata.validate_data(cleaned_data)
captured = capsys.readouterr()

# Assert
pd.testing.assert_frame_equal(cleaned_data['parent'], data['parent'])
pd.testing.assert_frame_equal(cleaned_data['child'], data['child'].iloc[:4])
assert len(cleaned_data['child']) == 4
expected_output = (
'Success! All foreign keys have referential integrity.\n\n'
'Table Name # Rows (Original) # Invalid Rows # Rows (New)\n'
' child 5 1 4\n'
' parent 5 0 5'
)
assert captured.out.strip() == expected_output


def test_drop_unknown_references_not_drop_missing_values(metadata, data):
Expand All @@ -118,7 +142,9 @@ def test_drop_unknown_references_not_drop_missing_values(metadata, data):
data['child'].loc[3, 'parent_id'] = np.nan

# Run
cleaned_data = drop_unknown_references(metadata, data, drop_missing_values=False)
cleaned_data = drop_unknown_references(
metadata, data, drop_missing_values=False, verbose=False
)

# Assert
pd.testing.assert_frame_equal(cleaned_data['parent'], data['parent'])
Expand Down
29 changes: 25 additions & 4 deletions tests/unit/utils/test_poc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
from sdv.utils.poc import drop_unknown_references


@patch('sys.stdout.write')
@patch('sdv.utils.poc._get_rows_to_drop')
def test_drop_unknown_references(mock_get_rows_to_drop):
def test_drop_unknown_references(mock_get_rows_to_drop, mock_stdout_write):
"""Test ``drop_unknown_references``."""
# Setup
relationships = [
Expand Down Expand Up @@ -65,6 +66,15 @@ def test_drop_unknown_references(mock_get_rows_to_drop):
result = drop_unknown_references(metadata, data)

# Assert
expected_pattern = re.compile(
r'Success! All foreign keys have referential integrity\.\s*'
r'Table Name\s*#\s*Rows \(Original\)\s*#\s*Invalid Rows\s*#\s*Rows \(New\)\s*'
r'child\s*5\s*1\s*4\s*'
r'grandchild\s*5\s*3\s*2\s*'
r'parent\s*5\s*0\s*5'
)
output = mock_stdout_write.call_args[0][0]
assert expected_pattern.match(output)
metadata.validate.assert_called_once()
metadata.validate_data.assert_called_once_with(data)
mock_get_rows_to_drop.assert_called_once()
Expand All @@ -88,13 +98,15 @@ def test_drop_unknown_references(mock_get_rows_to_drop):
pd.testing.assert_frame_equal(table, expected_result[table_name])


def test_drop_unknown_references_valid_data_mock():
@patch('sys.stdout.write')
def test_drop_unknown_references_valid_data_mock(mock_stdout_write):
"""Test ``drop_unknown_references`` when data has referential integrity."""
# Setup
metadata = Mock()
metadata._get_all_foreign_keys.side_effect = [
[], ['parent_foreign_key'], ['child_foreign_key', 'parent_foreign_key']
]
metadata.tables = {'parent', 'child', 'grandchild'}
data = {
'parent': pd.DataFrame({
'id_parent': [0, 1, 2, 3, 4],
Expand All @@ -116,6 +128,15 @@ def test_drop_unknown_references_valid_data_mock():
result = drop_unknown_references(metadata, data)

# Assert
expected_pattern = re.compile(
r'Success! All foreign keys have referential integrity\.\s*'
r'Table Name\s*#\s*Rows \(Original\)\s*#\s*Invalid Rows\s*#\s*Rows \(New\)\s*'
r'child\s*5\s*0\s*5\s*'
r'grandchild\s*5\s*0\s*5\s*'
r'parent\s*5\s*0\s*5'
)
output = mock_stdout_write.call_args[0][0]
assert expected_pattern.match(output)
metadata.validate.assert_called_once()
metadata.validate_data.assert_called_once_with(data)
for table_name, table in result.items():
Expand Down Expand Up @@ -175,7 +196,7 @@ def test_drop_unknown_references_with_nan(mock_validate_foreign_keys, mock_get_r
})

# Run
result = drop_unknown_references(metadata, data)
result = drop_unknown_references(metadata, data, verbose=False)

# Assert
metadata.validate.assert_called_once()
Expand Down Expand Up @@ -255,7 +276,7 @@ def test_drop_unknown_references_drop_missing_values_false(mock_get_rows_to_drop
})

# Run
result = drop_unknown_references(metadata, data, drop_missing_values=False)
result = drop_unknown_references(metadata, data, drop_missing_values=False, verbose=False)

# Assert
mock_get_rows_to_drop.assert_called_once()
Expand Down
Loading