Skip to content

Commit

Permalink
Add validate to SingleTableMetadata (#930)
Browse files Browse the repository at this point in the history
* Add key validation methods (the methods themselves still need to be revised)

* Combined errors implementation

* Finish constraints

* Add silly constraint solution

* Add try/catches for validation + add unit and integration tests

* Change MetadataError to InvalidMetadataError

* Simplify tests to one element + fix lint

* Remove _metadata

* Change the way _metadata is implemented in to_dict

* Add constraint validation + integration test + change _constraints to be a list of tuples

* Fix error message for validate + add error integration test

* Fix all constraint related error messages

* Add helper _append_error to validate + fix constraint related errors

* Add constraint unit test

* Fix lint

* change _metadata[pk] to pk

* Add requested tests + address feedback

* Address minor feedback
  • Loading branch information
fealho authored Aug 4, 2022
1 parent 2fde2d0 commit aabcaad
Show file tree
Hide file tree
Showing 12 changed files with 560 additions and 195 deletions.
13 changes: 7 additions & 6 deletions sdv/constraints/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ def _validate_inputs(cls, **kwargs):
f'Invalid values {invalid_vals} are present in {article} {constraint} constraint.'
))

raise MultipleConstraintsErrors(errors)
if errors:
raise MultipleConstraintsErrors(errors)

@classmethod
def _validate_metadata_columns(cls, metadata, **kwargs):
Expand All @@ -135,16 +136,16 @@ def _validate_metadata_columns(cls, metadata, **kwargs):
else:
column_names = kwargs.get('column_names')

missing_columns = set(column_names) - set(metadata._columns)

missing_columns = set(column_names) - set(metadata._columns) - {None}
if missing_columns:
article = 'An' if cls.__name__ == 'Inequality' else 'A'
raise ConstraintMetadataError(
f'A {cls.__name__} constraint is being applied to invalid column names '
f'{article} {cls.__name__} constraint is being applied to invalid column names '
f'{missing_columns}. The columns must exist in the table.'
)

@classmethod
def _validate_metadata_specific_to_constraint(cls, metadata, **kwargs):
@staticmethod
def _validate_metadata_specific_to_constraint(metadata, **kwargs):
pass

@classmethod
Expand Down
54 changes: 28 additions & 26 deletions sdv/constraints/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,18 +345,18 @@ def _validate_metadata_columns(cls, metadata, **kwargs):
kwargs['column_names'] = [kwargs.get('high_column_name'), kwargs.get('low_column_name')]
super()._validate_metadata_columns(metadata, **kwargs)

@classmethod
def _validate_metadata_specific_to_constraint(cls, metadata, **kwargs):
@staticmethod
def _validate_metadata_specific_to_constraint(metadata, **kwargs):
high = kwargs.get('high_column_name')
low = kwargs.get('low_column_name')
high_sdtype = metadata._columns.get(high, {}).get('sdtype')
low_sdtype = metadata._columns.get(low, {}).get('sdtype')
both_datetime = high_sdtype == low_sdtype == 'datetime'
both_numerical = high_sdtype == low_sdtype == 'numerical'
if not (both_datetime or both_numerical):
if not (both_datetime or both_numerical) and not (high is None or low is None):
raise ConstraintMetadataError(
f'An {cls.__name__} constraint is being applied to mismatched sdtypes '
f'{[high, low]}. Both columns must be either numerical or datetime.'
'An Inequality constraint is being applied to columns with mismatched sdtypes'
f' {[high, low]}. Both columns must be either numerical or datetime.'
)

def __init__(self, low_column_name, high_column_name, strict_boundaries=False):
Expand Down Expand Up @@ -504,26 +504,27 @@ def _validate_inputs(cls, **kwargs):
if errors:
raise MultipleConstraintsErrors(errors)

@classmethod
def _validate_metadata_specific_to_constraint(cls, metadata, **kwargs):
@staticmethod
def _validate_metadata_specific_to_constraint(metadata, **kwargs):
column_name = kwargs.get('column_name')
sdtype = metadata._columns.get(column_name, {}).get('sdtype')
val = kwargs.get('value')
if sdtype == 'numerical':
if not isinstance(val, (int, float)):
raise ConstraintMetadataError("'value' must be an int or float")
raise ConstraintMetadataError("'value' must be an int or float.")

elif sdtype == 'datetime':
datetime_format = metadata._columns.get(column_name).get('datetime_format')
matches_format = matches_datetime_format(val, datetime_format)
if not matches_format:
raise ConstraintMetadataError(
"'value' must be a datetime string of the right format"
"'value' must be a datetime string of the right format."
)

else:
raise ConstraintMetadataError(
f'A {cls.__name__} constraint is being applied to mismatched sdtypes. '
'A ScalarInequality constraint is being applied '
'to columns with mismatched sdtypes. '
'Numerical columns must be compared to integer or float values. '
'Datetimes column must be compared to datetime strings.'
)
Expand Down Expand Up @@ -672,13 +673,13 @@ class Positive(ScalarInequality):
zero ``>`` or include it ``>=``.
"""

@classmethod
def _validate_metadata_specific_to_constraint(cls, metadata, **kwargs):
@staticmethod
def _validate_metadata_specific_to_constraint(metadata, **kwargs):
column_name = kwargs.get('column_name')
sdtype = metadata._columns.get(column_name, {}).get('sdtype')
if sdtype != 'numerical':
raise ConstraintMetadataError(
f'A {cls.__name__} constraint is being applied to an invalid column '
f'A Positive constraint is being applied to an invalid column '
f"'{column_name}'. This constraint is only defined for numerical columns."
)

Expand All @@ -700,13 +701,13 @@ class Negative(ScalarInequality):
zero ``<`` or include it ``<=``.
"""

@classmethod
def _validate_metadata_specific_to_constraint(cls, metadata, **kwargs):
@staticmethod
def _validate_metadata_specific_to_constraint(metadata, **kwargs):
column_name = kwargs.get('column_name')
sdtype = metadata._columns.get(column_name, {}).get('sdtype')
if sdtype != 'numerical':
raise ConstraintMetadataError(
f'A {cls.__name__} constraint is being applied to an invalid column '
f'A Negative constraint is being applied to an invalid column '
f"'{column_name}'. This constraint is only defined for numerical columns."
)

Expand Down Expand Up @@ -742,8 +743,8 @@ def _validate_metadata_columns(cls, metadata, **kwargs):
kwargs['column_names'] = [high, low, middle]
super()._validate_metadata_columns(metadata, **kwargs)

@classmethod
def _validate_metadata_specific_to_constraint(cls, metadata, **kwargs):
@staticmethod
def _validate_metadata_specific_to_constraint(metadata, **kwargs):
high = kwargs.get('high_column_name')
low = kwargs.get('low_column_name')
middle = kwargs.get('middle_column_name')
Expand All @@ -752,9 +753,10 @@ def _validate_metadata_specific_to_constraint(cls, metadata, **kwargs):
middle_sdtype = metadata._columns.get(middle, {}).get('sdtype')
all_datetime = high_sdtype == low_sdtype == middle_sdtype == 'datetime'
all_numerical = high_sdtype == low_sdtype == middle_sdtype == 'numerical'
if not (all_datetime or all_numerical):
if not (all_datetime or all_numerical) and \
not (high is None or low is None or middle is None):
raise ConstraintMetadataError(
f'A {cls.__name__} constraint is being applied to mismatched sdtypes '
'A Range constraint is being applied to columns with mismatched sdtypes '
f'{[high, middle, low]}. All columns must be either numerical or datetime.'
)

Expand Down Expand Up @@ -926,12 +928,12 @@ def _validate_init_inputs(low_value, high_value):
'represents a datetime.'
)

@classmethod
def _validate_metadata_specific_to_constraint(cls, metadata, **kwargs):
@staticmethod
def _validate_metadata_specific_to_constraint(metadata, **kwargs):
column_name = kwargs.get('column_name')
if column_name not in metadata._columns:
raise ConstraintMetadataError(
f'A {cls.__name__} constraint is being applied to invalid column names '
f'A ScalarRange constraint is being applied to invalid column names '
f'({column_name}). The columns must exist in the table.'
)
sdtype = metadata._columns.get(column_name).get('sdtype')
Expand All @@ -955,7 +957,7 @@ def _validate_metadata_specific_to_constraint(cls, metadata, **kwargs):

else:
raise ConstraintMetadataError(
f'A {cls.__name__} constraint is being applied to mismatched sdtypes. '
'A ScalarRange constraint is being applied to columns with mismatched sdtypes. '
'Numerical columns must be compared to integer or float values. '
'Datetimes column must be compared to datetime strings.'
)
Expand Down Expand Up @@ -1263,8 +1265,8 @@ def __init__(self, column_names):
self.column_names = column_names
self.constraint_columns = tuple(self.column_names)

@classmethod
def _validate_metadata_specific_to_constraint(cls, metadata, **kwargs):
@staticmethod
def _validate_metadata_specific_to_constraint(metadata, **kwargs):
column_names = kwargs.get('column_names')
keys = set()
if isinstance(metadata._primary_key, tuple):
Expand Down
4 changes: 2 additions & 2 deletions sdv/metadata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

from sdv.metadata import visualization
from sdv.metadata.dataset import Metadata
from sdv.metadata.errors import MetadataError, MetadataNotFittedError
from sdv.metadata.errors import InvalidMetadataError, MetadataNotFittedError
from sdv.metadata.multi_table import MultiTableMetadata
from sdv.metadata.single_table import SingleTableMetadata
from sdv.metadata.table import Table

__all__ = (
'Metadata',
'MetadataError',
'InvalidMetadataError',
'MetadataNotFittedError',
'MultiTableMetadata',
'SingleTableMetadata',
Expand Down
11 changes: 6 additions & 5 deletions sdv/metadata/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from sdv.constraints import Constraint
from sdv.metadata import visualization
from sdv.metadata.errors import MetadataError
from sdv.metadata.errors import InvalidMetadataError

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -551,7 +551,7 @@ def _validate_table(self, table_name, table_meta, table_data=None, errors=None):
on the metadata.
Raises:
MetadataError:
InvalidMetadataError:
If there is any error in the metadata or the data does not
match the metadata description.
"""
Expand Down Expand Up @@ -635,7 +635,7 @@ def validate(self, tables=None):
"""
tables_meta = self._metadata.get('tables')
if not tables_meta:
raise MetadataError('"tables" entry not found in Metadata.')
raise InvalidMetadataError('"tables" entry not found in Metadata.')

if tables and not isinstance(tables, dict):
tables = self.load_tables()
Expand All @@ -654,7 +654,8 @@ def validate(self, tables=None):
self._validate_circular_relationships(table_name, errors=errors)

if errors:
raise MetadataError('Invalid Metadata specification:\n - ' + '\n - '.join(errors))
raise InvalidMetadataError(
'Invalid Metadata specification:\n - ' + '\n - '.join(errors))

def _check_field(self, table, field, exists=False):
"""Validate the existance of the table and existance (or not) of field."""
Expand Down Expand Up @@ -832,7 +833,7 @@ def add_relationship(self, parent, child, foreign_key=None, validate=True):
if validate:
try:
self.validate()
except MetadataError:
except InvalidMetadataError:
self._metadata = metadata_backup
raise

Expand Down
4 changes: 2 additions & 2 deletions sdv/metadata/errors.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""Metadata Exceptions."""


class MetadataError(Exception):
class InvalidMetadataError(Exception):
"""Error to raise when Metadata is not valid."""


class MetadataNotFittedError(MetadataError):
class MetadataNotFittedError(InvalidMetadataError):
"""Error to raise when Metadata is used before fitting."""
4 changes: 2 additions & 2 deletions sdv/metadata/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def visualize(self, show_table_details=True, show_relationship_labels=True,
columns = [f"{name} : {meta.get('sdtype')}" for name, meta in column_dict]
nodes[table_name] = {
'columns': r'\l'.join(columns),
'primary_key': f"Primary key: {table_meta._metadata['primary_key']}"
'primary_key': f'Primary key: {table_meta._primary_key}'
}

else:
Expand All @@ -72,7 +72,7 @@ def visualize(self, show_table_details=True, show_relationship_labels=True,
parent = relationship.get('parent_table_name')
child = relationship.get('child_table_name')
foreign_key = relationship.get('child_foreign_key')
primary_key = self._tables.get(parent)._metadata.get('primary_key')
primary_key = self._tables.get(parent)._primary_key
edge_label = f' {foreign_key}{primary_key}' if show_relationship_labels else ''
edges.append((parent, child, edge_label))

Expand Down
Loading

0 comments on commit aabcaad

Please sign in to comment.