diff --git a/.gitignore b/.gitignore index 747cbfc02..241e9f48f 100644 --- a/.gitignore +++ b/.gitignore @@ -104,3 +104,5 @@ ENV/ # Vim .*.swp + +sdv/data/ diff --git a/examples/6. Metadata Validation.ipynb b/examples/6. Metadata Validation.ipynb new file mode 100644 index 000000000..e1981fd24 --- /dev/null +++ b/examples/6. Metadata Validation.ipynb @@ -0,0 +1,244 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from sdv import load_demo" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "metadata, tables = load_demo(metadata=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'tables': {'users': {'primary_key': 'user_id',\n", + " 'fields': {'user_id': {'type': 'id', 'subtype': 'integer'},\n", + " 'country': {'type': 'categorical'},\n", + " 'gender': {'type': 'categorical'},\n", + " 'age': {'type': 'numerical', 'subtype': 'integer'}}},\n", + " 'sessions': {'primary_key': 'session_id',\n", + " 'fields': {'session_id': {'type': 'id', 'subtype': 'integer'},\n", + " 'user_id': {'ref': {'field': 'user_id', 'table': 'users'},\n", + " 'type': 'id',\n", + " 'subtype': 'integer'},\n", + " 'device': {'type': 'categorical'},\n", + " 'os': {'type': 'categorical'}}},\n", + " 'transactions': {'primary_key': 'transaction_id',\n", + " 'fields': {'transaction_id': {'type': 'id', 'subtype': 'integer'},\n", + " 'session_id': {'ref': {'field': 'session_id', 'table': 'sessions'},\n", + " 'type': 'id',\n", + " 'subtype': 'integer'},\n", + " 'timestamp': {'type': 'datetime', 'format': '%Y-%m-%d'},\n", + " 'amount': {'type': 'numerical', 'subtype': 'float'},\n", + " 'approved': {'type': 'boolean'}}}}}" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metadata.to_dict()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'users': user_id country gender age\n", + " 0 0 USA M 34\n", + " 1 1 UK F 23\n", + " 2 2 ES None 44\n", + " 3 3 UK M 22\n", + " 4 4 USA F 54\n", + " 5 5 DE M 57\n", + " 6 6 BG F 45\n", + " 7 7 ES None 41\n", + " 8 8 FR F 23\n", + " 9 9 UK None 30,\n", + " 'sessions': session_id user_id device os\n", + " 0 0 0 mobile android\n", + " 1 1 1 tablet ios\n", + " 2 2 1 tablet android\n", + " 3 3 2 mobile android\n", + " 4 4 4 mobile ios\n", + " 5 5 5 mobile android\n", + " 6 6 6 mobile ios\n", + " 7 7 6 tablet ios\n", + " 8 8 6 mobile ios\n", + " 9 9 8 tablet ios,\n", + " 'transactions': transaction_id session_id timestamp amount approved\n", + " 0 0 0 2019-01-01 12:34:32 100.0 True\n", + " 1 1 0 2019-01-01 12:42:21 55.3 True\n", + " 2 2 1 2019-01-07 17:23:11 79.5 True\n", + " 3 3 3 2019-01-10 11:08:57 112.1 False\n", + " 4 4 5 2019-01-10 21:54:08 110.0 False\n", + " 5 5 5 2019-01-11 11:21:20 76.3 True\n", + " 6 6 7 2019-01-22 14:44:10 89.5 True\n", + " 7 7 8 2019-01-23 10:14:09 132.1 False\n", + " 8 8 9 2019-01-27 16:09:17 68.0 True\n", + " 9 9 9 2019-01-29 12:10:48 99.9 True}" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tables" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "metadata.validate()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "metadata.validate(tables)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "metadata._metadata['tables']['users']['primary_key'] = 'country'" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "ename": "MetadataError", + "evalue": "id field `user_id` is neither a primary or a foreign key", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mMetadataError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmetadata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalidate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/Projects/MIT/SDV/sdv/metadata.py\u001b[0m in \u001b[0;36mvalidate\u001b[0;34m(self, tables)\u001b[0m\n\u001b[1;32m 612\u001b[0m \u001b[0mtable\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 613\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 614\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_validate_table\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtable_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtable_meta\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtable\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 615\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_validate_circular_relationships\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtable_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 616\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_validate_parents\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtable_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Projects/MIT/SDV/sdv/metadata.py\u001b[0m in \u001b[0;36m_validate_table\u001b[0;34m(self, table_name, table_meta, table_data)\u001b[0m\n\u001b[1;32m 525\u001b[0m \u001b[0mmatch\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mmetadata\u001b[0m \u001b[0mdescription\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 526\u001b[0m \"\"\"\n\u001b[0;32m--> 527\u001b[0;31m \u001b[0mdtypes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_dtypes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtable_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mids\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 528\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 529\u001b[0m \u001b[0;31m# Primary key field exists and its type is 'id'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Projects/MIT/SDV/sdv/metadata.py\u001b[0m in \u001b[0;36mget_dtypes\u001b[0;34m(self, table_name, ids)\u001b[0m\n\u001b[1;32m 364\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mtable_meta\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'primary_key'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mfield\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'ref'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 365\u001b[0m raise MetadataError(\n\u001b[0;32m--> 366\u001b[0;31m 'id field `{}` is neither a primary or a foreign key'.format(name))\n\u001b[0m\u001b[1;32m 367\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 368\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mids\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mfield_type\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;34m'id'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mMetadataError\u001b[0m: id field `user_id` is neither a primary or a foreign key" + ] + } + ], + "source": [ + "metadata.validate()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "metadata._metadata['tables']['users']['primary_key'] = 'user_id'" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "metadata.validate()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "metadata._metadata['tables']['users']['fields']['gender']['type'] = 'numerical'" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "metadata.validate()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "ename": "MetadataError", + "evalue": "Invalid values found in column gender of table users: could not convert string to float: 'M'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mMetadataError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmetadata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalidate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtables\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/Projects/MIT/SDV/sdv/metadata.py\u001b[0m in \u001b[0;36mvalidate\u001b[0;34m(self, tables)\u001b[0m\n\u001b[1;32m 612\u001b[0m \u001b[0mtable\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 613\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 614\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_validate_table\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtable_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtable_meta\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtable\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 615\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_validate_circular_relationships\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtable_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 616\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_validate_parents\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtable_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Projects/MIT/SDV/sdv/metadata.py\u001b[0m in \u001b[0;36m_validate_table\u001b[0;34m(self, table_name, table_meta, table_data)\u001b[0m\n\u001b[1;32m 546\u001b[0m message = 'Invalid values found in column {} of table {}: {}'.format(\n\u001b[1;32m 547\u001b[0m column, table_name, ve)\n\u001b[0;32m--> 548\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mMetadataError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmessage\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 549\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 550\u001b[0m \u001b[0;31m# assert all dtypes are in data\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mMetadataError\u001b[0m: Invalid values found in column gender of table users: could not convert string to float: 'M'" + ] + } + ], + "source": [ + "metadata.validate(tables)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/sdv/metadata.py b/sdv/metadata.py index 64fd20211..56c05a340 100644 --- a/sdv/metadata.py +++ b/sdv/metadata.py @@ -50,6 +50,10 @@ def _load_csv(root_path, table_meta): return data +class MetadataError(Exception): + pass + + class Metadata: """Dataset Metadata. @@ -351,14 +355,14 @@ def get_dtypes(self, table_name, ids=False): field_subtype = field.get('subtype') dtype = self._DTYPES.get((field_type, field_subtype)) if not dtype: - raise ValueError( + raise MetadataError( 'Invalid type and subtype combination for field {}: ({}, {})'.format( name, field_type, field_subtype) ) if ids and field_type == 'id': if (name != table_meta.get('primary_key')) and not field.get('ref'): - raise ValueError( + raise MetadataError( 'id field `{}` is neither a primary or a foreign key'.format(name)) if ids or (field_type != 'id'): @@ -489,6 +493,128 @@ def reverse_transform(self, table_name, data): return reversed_data + # ################### # + # Metadata Validation # + # ################### # + + def _validate_table(self, table_name, table_meta, table_data=None): + """Validate table metadata. + + Validate the type and subtype combination for each field in ``table_meta``. + If a field has type ``id``, validate that it either is the ``primary_key`` or + has a ``ref`` entry. + + If the table has ``primary_key``, make sure that the corresponding field exists + and its type is ``id``. + + If ``table_data`` is provided, also check that the list of columns corresponds + to the ones indicated in the metadata and that all the dtypes are valid. + + Args: + table_name (str): + Name of the table to validate. + table_meta (dict): + Metadata of the table to validate. + table_data (pandas.DataFrame): + If provided, make sure that the data matches the one described + on the metadata. + + Raises: + MetadataError: + If there is any error in the metadata or the data does not + match the metadata description. + """ + dtypes = self.get_dtypes(table_name, ids=True) + + # Primary key field exists and its type is 'id' + primary_key = table_meta.get('primary_key') + if primary_key: + pk_field = table_meta['fields'].get(primary_key) + + if not pk_field: + raise MetadataError('Primary key is not an existing field.') + + if pk_field['type'] != 'id': + raise MetadataError('Primary key is not of type `id`.') + + if table_data is not None: + for column in table_data: + try: + table_data[column].astype(dtypes[column]) + del dtypes[column] + except ValueError as ve: + message = 'Invalid values found in column {} of table {}: {}'.format( + column, table_name, ve) + raise MetadataError(message) from None + + # assert all dtypes are in data + if dtypes: + raise MetadataError( + 'Missing columns on table {}: {}.'.format(table_name, list(dtypes.keys())) + ) + + def _validate_circular_relationships(self, parent, children=None): + """Validate that there is no circular relatioship in the metadata.""" + if children is None: + children = self.get_children(parent) + + if parent in children: + raise MetadataError('Circular relationship found for table "{}"'.format(parent)) + + for child in children: + self._validate_circular_relationships(parent, self.get_children(child)) + + def _validate_parents(self, table_name): + """Make sure that the table has only one parent.""" + if len(self.get_parents(table_name)) > 1: + raise MetadataError('Table {} has more than one parent.'.format(table_name)) + + def validate(self, tables=None): + """Validate this metadata. + + For each table from in metadata ``tables`` entry: + * Validate the table metadata is correct. + + * If ``tables`` are provided or they have been loaded, check + that all the metadata tables exists in the ``tables`` dictionary. + * Validate the type/subtype combination for each field and + if a field of type ``id`` exists it must be the ``primary_key`` + or must have a ``ref`` entry. + * If ``primary_key`` entry exists, check that it's an existing + field and its type is ``id``. + * If ``tables`` are provided or they have been loaded, check + all the data types for the table correspond to each column and + all the data types exists on the table. + * Validate that there is no circular relatioship in the metadata. + * Check that all the tables have at most one parent. + + Args: + tables (bool, dict): + If a dict of table is passed, validate that the columns and + dtypes match the metadata. If ``True`` is passed, load the + tables from the Metadata instead. If ``None``, omit the data + validation. Defaults to ``None``. + """ + tables_meta = self._metadata.get('tables') + if not tables_meta: + raise MetadataError('"tables" entry not found in Metadata.') + + if tables and not isinstance(tables, dict): + tables = self.load_tables() + + for table_name, table_meta in tables_meta.items(): + if tables: + table = tables.get(table_name) + if table is None: + raise MetadataError('Table `{}` not found in tables'.format(table_name)) + + else: + table = None + + self._validate_table(table_name, table_meta, table) + self._validate_circular_relationships(table_name) + self._validate_parents(table_name) + def _check_field(self, table, field, exists=False): """Validate the existance of the table and existance (or not) of field.""" table_fields = self.get_fields(table) @@ -498,6 +624,10 @@ def _check_field(self, table, field, exists=False): if not exists and (field in table_fields): raise ValueError('Field "{}" already exists in table "{}"'.format(field, table)) + # ################# # + # Metadata Creation # + # ################# # + def add_field(self, table, field, field_type, field_subtype=None, properties=None): """Add a new field to the indicated table. @@ -580,17 +710,6 @@ def set_primary_key(self, table, field): } table_meta['primary_key'] = field - def _validate_circular_relationships(self, parent, children=None): - """Validate that there is no circular relatioship in the metadata.""" - if children is None: - children = self.get_children(parent) - - if parent in children: - raise ValueError('Circular relationship found for table "{}"'.format(parent)) - - for child in children: - self._validate_circular_relationships(parent, self.get_children(child)) - def add_relationship(self, parent, child, foreign_key=None): """Add a new relationship between the parent and child tables. @@ -762,6 +881,10 @@ def add_table(self, name, data=None, fields=None, fields_metadata=None, del self._metadata['tables'][name] raise + # ###################### # + # Metadata Serialization # + # ###################### # + def to_dict(self): """Get a dict representation of this metadata. diff --git a/sdv/sdv.py b/sdv/sdv.py index cdd4b6a28..3905d348d 100644 --- a/sdv/sdv.py +++ b/sdv/sdv.py @@ -40,12 +40,6 @@ def __init__(self, model=DEFAULT_MODEL, model_kwargs=None): else: self.model_kwargs = model_kwargs - def _validate_dataset_structure(self): - """Make sure that all the tables have at most one parent.""" - for table in self.metadata.get_tables(): - if len(self.metadata.get_parents(table)) > 1: - raise ValueError('Some tables have multiple parents, which is not supported yet.') - def fit(self, metadata, tables=None, root_path=None): """Fit this SDV instance to the dataset data. @@ -67,7 +61,7 @@ def fit(self, metadata, tables=None, root_path=None): else: self.metadata = Metadata(metadata, root_path) - self._validate_dataset_structure() + self.metadata.validate(tables) self.modeler = Modeler(self.metadata, self.model, self.model_kwargs) self.modeler.model_database(tables) diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 03f8fbac2..f4605a697 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -1,10 +1,10 @@ from unittest import TestCase -from unittest.mock import Mock, call, patch +from unittest.mock import MagicMock, Mock, call, patch import pandas as pd import pytest -from sdv.metadata import Metadata, _load_csv, _parse_dtypes, _read_csv_dtypes +from sdv.metadata import Metadata, MetadataError, _load_csv, _parse_dtypes, _read_csv_dtypes def test__read_csv_dtypes(): @@ -217,6 +217,28 @@ def test__dict_metadata_dict(self): } assert result == expected + def test__validate_parents_no_error(self): + """Test that any error is raised with a supported structure""" + # Setup + mock = MagicMock(spec=Metadata) + mock.get_parents.return_value = [] + + # Run + Metadata._validate_parents(mock, 'demo') + + # Asserts + mock.get_parents.assert_called_once_with('demo') + + def test__validate_parents_raise_error(self): + """Test that a ValueError is raised because the bad structure""" + # Setup + mock = MagicMock(spec=Metadata) + mock.get_parents.return_value = ['foo', 'bar'] + + # Run + with pytest.raises(MetadataError): + Metadata._validate_parents(mock, 'demo') + @patch('sdv.metadata.Metadata._analyze_relationships') @patch('sdv.metadata.Metadata._dict_metadata') def test___init__default_metadata_dict(self, mock_meta, mock_relationships): @@ -367,7 +389,7 @@ def test_get_dtypes_error_invalid_type(self): metadata._DTYPES = Metadata._DTYPES # Run - with pytest.raises(ValueError): + with pytest.raises(MetadataError): Metadata.get_dtypes(metadata, 'test') def test_get_dtypes_error_id(self): @@ -383,7 +405,7 @@ def test_get_dtypes_error_id(self): metadata._DTYPES = Metadata._DTYPES # Run - with pytest.raises(ValueError): + with pytest.raises(MetadataError): Metadata.get_dtypes(metadata, 'test', ids=True) def test_get_dtypes_error_subtype_numerical(self): @@ -399,7 +421,7 @@ def test_get_dtypes_error_subtype_numerical(self): metadata._DTYPES = Metadata._DTYPES # Run - with pytest.raises(ValueError): + with pytest.raises(MetadataError): Metadata.get_dtypes(metadata, 'test') def test_get_dtypes_error_subtype_id(self): @@ -415,7 +437,7 @@ def test_get_dtypes_error_subtype_id(self): metadata._DTYPES = Metadata._DTYPES # Run - with pytest.raises(ValueError): + with pytest.raises(MetadataError): Metadata.get_dtypes(metadata, 'test', ids=True) def test__get_pii_fields(self): diff --git a/tests/test_sdv.py b/tests/test_sdv.py index e6844a4d1..f1373129e 100644 --- a/tests/test_sdv.py +++ b/tests/test_sdv.py @@ -51,30 +51,6 @@ def test____init__users_params(self): assert sdv.model == 'test' assert sdv.model_kwargs == {'a': 2} - def test__validate_dataset_structure_no_error(self): - """Test that any error is raised with a supported structure""" - # Setup - sdv = Mock() - sdv.metadata.get_tables.return_value = ['foo', 'bar', 'tar'] - sdv.metadata.get_parents.side_effect = [[], ['foo'], ['bar']] - - # Run - SDV._validate_dataset_structure(sdv) - - # Asserts - assert sdv.metadata.get_parents.call_count == 3 - - def test__validate_dataset_structure_raise_error(self): - """Test that a ValueError is raised because the bad structure""" - # Setup - sdv = Mock() - sdv.metadata.get_tables.return_value = ['foo', 'bar', 'tar'] - sdv.metadata.get_parents.side_effect = [[], [], ['foo', 'bar']] - - # Run - with pytest.raises(ValueError): - SDV._validate_dataset_structure(sdv) - def test_sample_fitted(self): """Check that the sample is called.""" # Sample