Skip to content

Commit

Permalink
Metadata validation improvements (#354)
Browse files Browse the repository at this point in the history
* Improve metadata visualization

* Find the list of all errors when validating metadata

* Fix tests

* Fix tests

* Remove unused import
  • Loading branch information
csala authored Mar 23, 2021
1 parent 4143eb1 commit 1f0077a
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 33 deletions.
65 changes: 40 additions & 25 deletions sdv/metadata/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def load_tables(self, tables=None):
for table_name in tables or self.get_tables()
}

def get_dtypes(self, table_name, ids=False):
def get_dtypes(self, table_name, ids=False, errors=None):
"""Get a ``dict`` with the ``dtypes`` for each field of a given table.
Args:
Expand All @@ -378,26 +378,27 @@ def get_dtypes(self, table_name, ids=False):
If a field has an invalid type or subtype or if the table does not
exist in this metadata.
"""
errors = [] if errors is None else errors
dtypes = dict()
table_meta = self.get_table_meta(table_name)
for name, field in table_meta['fields'].items():
field_type = field['type']
field_subtype = field.get('subtype')
dtype = self._DTYPES.get((field_type, field_subtype))
if not dtype:
raise MetadataError(
errors.append(
'Invalid type and subtype combination for field {}: ({}, {})'.format(
name, field_type, field_subtype)
)
else:
if ids and field_type == 'id':
if (name != table_meta.get('primary_key')) and not field.get('ref'):
for child_table in self.get_children(table_name):
if name in self.get_foreign_keys(table_name, child_table):
break

if ids and field_type == 'id':
if (name != table_meta.get('primary_key')) and not field.get('ref'):
for child_table in self.get_children(table_name):
if name in self.get_foreign_keys(table_name, child_table):
break

if ids or (field_type != 'id'):
dtypes[name] = dtype
if ids or (field_type != 'id'):
dtypes[name] = dtype

return dtypes

Expand Down Expand Up @@ -528,7 +529,7 @@ def reverse_transform(self, table_name, data):
# Metadata Validation #
# ################### #

def _validate_table(self, table_name, table_meta, table_data=None):
def _validate_table(self, table_name, table_meta, table_data=None, errors=None):
"""Validate table metadata.
Validate the type and subtype combination for each field in ``table_meta``.
Expand All @@ -555,18 +556,20 @@ def _validate_table(self, table_name, table_meta, table_data=None):
If there is any error in the metadata or the data does not
match the metadata description.
"""
dtypes = self.get_dtypes(table_name, ids=True)
errors = [] if errors is None else errors
dtypes = self.get_dtypes(table_name, ids=True, errors=errors)

# Primary key field exists and its type is 'id'
primary_key = table_meta.get('primary_key')
if primary_key:
pk_field = table_meta['fields'].get(primary_key)

if not pk_field:
raise MetadataError('Primary key is not an existing field.')

if pk_field['type'] != 'id':
raise MetadataError('Primary key is not of type `id`.')
errors.append(
f'Invalid primary key: "{primary_key}" not found in table "{table_name}"')
elif pk_field['type'] != 'id':
errors.append(
f'Primary key "{primary_key}" of table "{table_name}" not of type "id"')

if table_data is not None:
for column in table_data:
Expand All @@ -575,28 +578,36 @@ def _validate_table(self, table_name, table_meta, table_data=None):
table_data[column].dropna().astype(dtype)
except KeyError:
message = 'Unexpected column in table `{}`: `{}`'.format(table_name, column)
raise MetadataError(message) from None
errors.append(message)
except ValueError as ve:
message = 'Invalid values found in column `{}` of table `{}`: `{}`'.format(
column, table_name, ve)
raise MetadataError(message) from None
errors.append(message)

# assert all dtypes are in data
if dtypes:
raise MetadataError(
errors.append(
'Missing columns on table {}: {}.'.format(table_name, list(dtypes.keys()))
)

def _validate_circular_relationships(self, parent, children=None):
def _validate_circular_relationships(self, parent, children=None, errors=None, parents=None):
"""Validate that there is no circular relatioship in the metadata."""
errors = [] if errors is None else errors
parents = set() if parents is None else parents
if children is None:
children = self.get_children(parent)

if parent in children:
raise MetadataError('Circular relationship found for table "{}"'.format(parent))
error = 'Circular relationship found for table "{}"'.format(parent)
errors.append(error)

for child in children:
self._validate_circular_relationships(parent, self.get_children(child))
if child in parents:
break

parents.add(child)
self._validate_circular_relationships(
parent, self.get_children(child), errors, parents)

def validate(self, tables=None):
"""Validate this metadata.
Expand Down Expand Up @@ -630,17 +641,21 @@ def validate(self, tables=None):
if tables and not isinstance(tables, dict):
tables = self.load_tables()

errors = []
for table_name, table_meta in tables_meta.items():
if tables:
table = tables.get(table_name)
if table is None:
raise MetadataError('Table `{}` not found in tables'.format(table_name))
errors.append('Table `{}` not found in tables'.format(table_name))

else:
table = None

self._validate_table(table_name, table_meta, table)
self._validate_circular_relationships(table_name)
self._validate_table(table_name, table_meta, table, errors)
self._validate_circular_relationships(table_name, errors=errors)

if errors:
raise MetadataError('Invalid Metadata specification:\n - ' + '\n - '.join(errors))

def _check_field(self, table, field, exists=False):
"""Validate the existance of the table and existance (or not) of field."""
Expand Down
21 changes: 13 additions & 8 deletions tests/unit/metadata/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
import pandas as pd
import pytest

from sdv.metadata.dataset import (
Metadata, MetadataError, _load_csv, _parse_dtypes, _read_csv_dtypes)
from sdv.metadata.dataset import Metadata, _load_csv, _parse_dtypes, _read_csv_dtypes


def test__read_csv_dtypes():
Expand Down Expand Up @@ -369,8 +368,10 @@ def test_get_dtypes_error_invalid_type(self):
metadata._DTYPES = Metadata._DTYPES

# Run
with pytest.raises(MetadataError):
Metadata.get_dtypes(metadata, 'test')
errors = []
Metadata.get_dtypes(metadata, 'test', errors=errors)

assert len(errors) == 1

def test_get_dtypes_error_subtype_numerical(self):
"""Test get data types with an invalid numerical subtype."""
Expand All @@ -385,8 +386,10 @@ def test_get_dtypes_error_subtype_numerical(self):
metadata._DTYPES = Metadata._DTYPES

# Run
with pytest.raises(MetadataError):
Metadata.get_dtypes(metadata, 'test')
errors = []
Metadata.get_dtypes(metadata, 'test', errors=errors)

assert len(errors) == 1

def test_get_dtypes_error_subtype_id(self):
"""Test get data types with an invalid id subtype."""
Expand All @@ -401,8 +404,10 @@ def test_get_dtypes_error_subtype_id(self):
metadata._DTYPES = Metadata._DTYPES

# Run
with pytest.raises(MetadataError):
Metadata.get_dtypes(metadata, 'test', ids=True)
errors = []
Metadata.get_dtypes(metadata, 'test', ids=True, errors=errors)

assert len(errors) == 1

def test__get_pii_fields(self):
"""Test get pii fields"""
Expand Down

0 comments on commit 1f0077a

Please sign in to comment.