Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add validate to SingleTableMetadata #930

Merged
merged 18 commits into from
Aug 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions sdv/constraints/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ def _validate_inputs(cls, **kwargs):
f'Invalid values {invalid_vals} are present in {article} {constraint} constraint.'
))

raise MultipleConstraintsErrors(errors)
if errors:
amontanez24 marked this conversation as resolved.
Show resolved Hide resolved
raise MultipleConstraintsErrors(errors)

@classmethod
def _validate_metadata_columns(cls, metadata, **kwargs):
Expand All @@ -135,16 +136,16 @@ def _validate_metadata_columns(cls, metadata, **kwargs):
else:
column_names = kwargs.get('column_names')

missing_columns = set(column_names) - set(metadata._columns)

missing_columns = set(column_names) - set(metadata._columns) - {None}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated. This happens when a column name was not passed/is missing.

if missing_columns:
article = 'An' if cls.__name__ == 'Inequality' else 'A'
raise ConstraintMetadataError(
f'A {cls.__name__} constraint is being applied to invalid column names '
f'{article} {cls.__name__} constraint is being applied to invalid column names '
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated.

f'{missing_columns}. The columns must exist in the table.'
)

@classmethod
def _validate_metadata_specific_to_constraint(cls, metadata, **kwargs):
@staticmethod
amontanez24 marked this conversation as resolved.
Show resolved Hide resolved
def _validate_metadata_specific_to_constraint(metadata, **kwargs):
pass

@classmethod
Expand Down
54 changes: 28 additions & 26 deletions sdv/constraints/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,18 +345,18 @@ def _validate_metadata_columns(cls, metadata, **kwargs):
kwargs['column_names'] = [kwargs.get('high_column_name'), kwargs.get('low_column_name')]
super()._validate_metadata_columns(metadata, **kwargs)

@classmethod
def _validate_metadata_specific_to_constraint(cls, metadata, **kwargs):
@staticmethod
def _validate_metadata_specific_to_constraint(metadata, **kwargs):
high = kwargs.get('high_column_name')
low = kwargs.get('low_column_name')
high_sdtype = metadata._columns.get(high, {}).get('sdtype')
low_sdtype = metadata._columns.get(low, {}).get('sdtype')
both_datetime = high_sdtype == low_sdtype == 'datetime'
both_numerical = high_sdtype == low_sdtype == 'numerical'
if not (both_datetime or both_numerical):
if not (both_datetime or both_numerical) and not (high is None or low is None):
raise ConstraintMetadataError(
f'An {cls.__name__} constraint is being applied to mismatched sdtypes '
f'{[high, low]}. Both columns must be either numerical or datetime.'
'An Inequality constraint is being applied to columns with mismatched sdtypes'
f' {[high, low]}. Both columns must be either numerical or datetime.'
)

def __init__(self, low_column_name, high_column_name, strict_boundaries=False):
Expand Down Expand Up @@ -504,26 +504,27 @@ def _validate_inputs(cls, **kwargs):
if errors:
raise MultipleConstraintsErrors(errors)

@classmethod
def _validate_metadata_specific_to_constraint(cls, metadata, **kwargs):
@staticmethod
def _validate_metadata_specific_to_constraint(metadata, **kwargs):
column_name = kwargs.get('column_name')
sdtype = metadata._columns.get(column_name, {}).get('sdtype')
val = kwargs.get('value')
if sdtype == 'numerical':
if not isinstance(val, (int, float)):
raise ConstraintMetadataError("'value' must be an int or float")
raise ConstraintMetadataError("'value' must be an int or float.")

elif sdtype == 'datetime':
datetime_format = metadata._columns.get(column_name).get('datetime_format')
matches_format = matches_datetime_format(val, datetime_format)
if not matches_format:
raise ConstraintMetadataError(
"'value' must be a datetime string of the right format"
"'value' must be a datetime string of the right format."
)

else:
raise ConstraintMetadataError(
f'A {cls.__name__} constraint is being applied to mismatched sdtypes. '
'A ScalarInequality constraint is being applied '
'to columns with mismatched sdtypes. '
'Numerical columns must be compared to integer or float values. '
'Datetimes column must be compared to datetime strings.'
)
Expand Down Expand Up @@ -672,13 +673,13 @@ class Positive(ScalarInequality):
zero ``>`` or include it ``>=``.
"""

@classmethod
def _validate_metadata_specific_to_constraint(cls, metadata, **kwargs):
@staticmethod
def _validate_metadata_specific_to_constraint(metadata, **kwargs):
column_name = kwargs.get('column_name')
sdtype = metadata._columns.get(column_name, {}).get('sdtype')
if sdtype != 'numerical':
raise ConstraintMetadataError(
f'A {cls.__name__} constraint is being applied to an invalid column '
f'A Positive constraint is being applied to an invalid column '
f"'{column_name}'. This constraint is only defined for numerical columns."
)

Expand All @@ -700,13 +701,13 @@ class Negative(ScalarInequality):
zero ``<`` or include it ``<=``.
"""

@classmethod
def _validate_metadata_specific_to_constraint(cls, metadata, **kwargs):
@staticmethod
def _validate_metadata_specific_to_constraint(metadata, **kwargs):
column_name = kwargs.get('column_name')
sdtype = metadata._columns.get(column_name, {}).get('sdtype')
if sdtype != 'numerical':
raise ConstraintMetadataError(
f'A {cls.__name__} constraint is being applied to an invalid column '
f'A Negative constraint is being applied to an invalid column '
f"'{column_name}'. This constraint is only defined for numerical columns."
)

Expand Down Expand Up @@ -742,8 +743,8 @@ def _validate_metadata_columns(cls, metadata, **kwargs):
kwargs['column_names'] = [high, low, middle]
super()._validate_metadata_columns(metadata, **kwargs)

@classmethod
def _validate_metadata_specific_to_constraint(cls, metadata, **kwargs):
@staticmethod
def _validate_metadata_specific_to_constraint(metadata, **kwargs):
high = kwargs.get('high_column_name')
low = kwargs.get('low_column_name')
middle = kwargs.get('middle_column_name')
Expand All @@ -752,9 +753,10 @@ def _validate_metadata_specific_to_constraint(cls, metadata, **kwargs):
middle_sdtype = metadata._columns.get(middle, {}).get('sdtype')
all_datetime = high_sdtype == low_sdtype == middle_sdtype == 'datetime'
all_numerical = high_sdtype == low_sdtype == middle_sdtype == 'numerical'
if not (all_datetime or all_numerical):
if not (all_datetime or all_numerical) and \
not (high is None or low is None or middle is None):
raise ConstraintMetadataError(
f'A {cls.__name__} constraint is being applied to mismatched sdtypes '
'A Range constraint is being applied to columns with mismatched sdtypes '
f'{[high, middle, low]}. All columns must be either numerical or datetime.'
)

Expand Down Expand Up @@ -926,12 +928,12 @@ def _validate_init_inputs(low_value, high_value):
'represents a datetime.'
)

@classmethod
def _validate_metadata_specific_to_constraint(cls, metadata, **kwargs):
@staticmethod
def _validate_metadata_specific_to_constraint(metadata, **kwargs):
column_name = kwargs.get('column_name')
if column_name not in metadata._columns:
raise ConstraintMetadataError(
f'A {cls.__name__} constraint is being applied to invalid column names '
f'A ScalarRange constraint is being applied to invalid column names '
f'({column_name}). The columns must exist in the table.'
)
sdtype = metadata._columns.get(column_name).get('sdtype')
Expand All @@ -955,7 +957,7 @@ def _validate_metadata_specific_to_constraint(cls, metadata, **kwargs):

else:
raise ConstraintMetadataError(
f'A {cls.__name__} constraint is being applied to mismatched sdtypes. '
'A ScalarRange constraint is being applied to columns with mismatched sdtypes. '
'Numerical columns must be compared to integer or float values. '
'Datetimes column must be compared to datetime strings.'
)
Expand Down Expand Up @@ -1263,8 +1265,8 @@ def __init__(self, column_names):
self.column_names = column_names
self.constraint_columns = tuple(self.column_names)

@classmethod
def _validate_metadata_specific_to_constraint(cls, metadata, **kwargs):
@staticmethod
def _validate_metadata_specific_to_constraint(metadata, **kwargs):
column_names = kwargs.get('column_names')
keys = set()
if isinstance(metadata._primary_key, tuple):
Expand Down
4 changes: 2 additions & 2 deletions sdv/metadata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

from sdv.metadata import visualization
from sdv.metadata.dataset import Metadata
from sdv.metadata.errors import MetadataError, MetadataNotFittedError
from sdv.metadata.errors import InvalidMetadataError, MetadataNotFittedError
from sdv.metadata.multi_table import MultiTableMetadata
from sdv.metadata.single_table import SingleTableMetadata
from sdv.metadata.table import Table

__all__ = (
'Metadata',
'MetadataError',
'InvalidMetadataError',
'MetadataNotFittedError',
'MultiTableMetadata',
'SingleTableMetadata',
Expand Down
11 changes: 6 additions & 5 deletions sdv/metadata/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from sdv.constraints import Constraint
from sdv.metadata import visualization
from sdv.metadata.errors import MetadataError
from sdv.metadata.errors import InvalidMetadataError

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -551,7 +551,7 @@ def _validate_table(self, table_name, table_meta, table_data=None, errors=None):
on the metadata.

Raises:
MetadataError:
InvalidMetadataError:
If there is any error in the metadata or the data does not
match the metadata description.
"""
Expand Down Expand Up @@ -635,7 +635,7 @@ def validate(self, tables=None):
"""
tables_meta = self._metadata.get('tables')
if not tables_meta:
raise MetadataError('"tables" entry not found in Metadata.')
raise InvalidMetadataError('"tables" entry not found in Metadata.')

if tables and not isinstance(tables, dict):
tables = self.load_tables()
Expand All @@ -654,7 +654,8 @@ def validate(self, tables=None):
self._validate_circular_relationships(table_name, errors=errors)

if errors:
raise MetadataError('Invalid Metadata specification:\n - ' + '\n - '.join(errors))
raise InvalidMetadataError(
'Invalid Metadata specification:\n - ' + '\n - '.join(errors))

def _check_field(self, table, field, exists=False):
"""Validate the existance of the table and existance (or not) of field."""
Expand Down Expand Up @@ -832,7 +833,7 @@ def add_relationship(self, parent, child, foreign_key=None, validate=True):
if validate:
try:
self.validate()
except MetadataError:
except InvalidMetadataError:
self._metadata = metadata_backup
raise

Expand Down
4 changes: 2 additions & 2 deletions sdv/metadata/errors.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""Metadata Exceptions."""


class MetadataError(Exception):
class InvalidMetadataError(Exception):
fealho marked this conversation as resolved.
Show resolved Hide resolved
"""Error to raise when Metadata is not valid."""


class MetadataNotFittedError(MetadataError):
class MetadataNotFittedError(InvalidMetadataError):
"""Error to raise when Metadata is used before fitting."""
4 changes: 2 additions & 2 deletions sdv/metadata/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def visualize(self, show_table_details=True, show_relationship_labels=True,
columns = [f"{name} : {meta.get('sdtype')}" for name, meta in column_dict]
nodes[table_name] = {
'columns': r'\l'.join(columns),
'primary_key': f"Primary key: {table_meta._metadata['primary_key']}"
'primary_key': f'Primary key: {table_meta._primary_key}'
}

else:
Expand All @@ -72,7 +72,7 @@ def visualize(self, show_table_details=True, show_relationship_labels=True,
parent = relationship.get('parent_table_name')
child = relationship.get('child_table_name')
foreign_key = relationship.get('child_foreign_key')
primary_key = self._tables.get(parent)._metadata.get('primary_key')
primary_key = self._tables.get(parent)._primary_key
edge_label = f' {foreign_key} → {primary_key}' if show_relationship_labels else ''
edges.append((parent, child, edge_label))

Expand Down
Loading