Metadata improvements #1610

Merged
merged 7 commits on Sep 27, 2023
9 changes: 9 additions & 0 deletions sdv/data_processing/data_processor.py
@@ -483,6 +483,15 @@ def _create_config(self, data, columns_created_by_constraints):
)
sdtypes[column] = 'pii'

elif sdtype == 'unknown':
transformers[column] = AnonymizedFaker(
function_name='bothify',
)
transformers[column].function_kwargs = {
'text': 'sdv-pii-?????',
'letters': '0123456789abcdefghijklmnopqrstuvwxyz'
}

elif pii:
enforce_uniqueness = bool(column in self._keys)
transformers[column] = self.create_anonymized_transformer(
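
For context, this configuration tags every value of an 'unknown' column with a recognizable placeholder pattern. A minimal sketch calling Faker directly, assuming only that function_kwargs is forwarded to Faker's bothify (text and letters are both real bothify keyword arguments):

    from faker import Faker

    fake = Faker()
    # Each '?' placeholder is replaced with a character drawn from `letters`.
    fake.bothify(text='sdv-pii-?????', letters='0123456789abcdefghijklmnopqrstuvwxyz')
    # e.g. 'sdv-pii-4k9x2'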
147 changes: 112 additions & 35 deletions sdv/metadata/multi_table.py
@@ -182,6 +182,15 @@ def _validate_relationship(self, parent_table_name, child_table_name,
child_foreign_key
)

def _get_parent_map(self):
parent_map = defaultdict(set)
for relation in self.relationships:
parent_name = relation['parent_table_name']
child_name = relation['child_table_name']
parent_map[child_name].add(parent_name)

return parent_map

def _get_child_map(self):
child_map = defaultdict(set)
for relation in self.relationships:
@@ -256,6 +265,32 @@ def add_relationship(self, parent_table_name, child_table_name,
'child_foreign_key': deepcopy(child_foreign_key),
})

def remove_relationship(self, parent_table_name, child_table_name):
"""Remove the relationship between two tables.

Args:
parent_table_name (str):
The name of the parent table.
child_table_name (str):
The name of the child table.
"""
relationships_to_remove = []
for relation in self.relationships:
if (relation['parent_table_name'] == parent_table_name and
relation['child_table_name'] == child_table_name):
relationships_to_remove.append(relation)

if not relationships_to_remove:
warning_msg = (
f"No existing relationships found between parent table '{parent_table_name}' and "
f"child table '{child_table_name}'."
)
warnings.warn(warning_msg)

else:
for relation in relationships_to_remove:
self.relationships.remove(relation)
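
A brief usage sketch with hypothetical table names; as the code above shows, the method warns instead of raising when no matching relationship exists:

    metadata.remove_relationship('users', 'sessions')
    # Removes every users -> sessions relationship.

    metadata.remove_relationship('users', 'sessions')
    # Nothing matches anymore, so this call only emits a UserWarning.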

def _validate_table_exists(self, table_name):
if table_name not in self.tables:
raise InvalidMetadataError(f"Unknown table name ('{table_name}').")
@@ -333,6 +368,79 @@ def _log_detected_table(single_table_metadata):
table_json = json.dumps(table_dict, indent=4)
LOGGER.info(f'Detected metadata:\n{table_json}')

def _validate_all_tables_connected(self, parent_map, child_map):
"""Get the connection status of all tables.

Args:
parent_map (dict):
Dictionary mapping each parent table to its child tables.
child_map (dict):
Dictionary mapping each child table to its parent tables.

Returns:
dict specifying whether each table is connected the other tables.
"""
nodes = list(self.tables.keys())
if len(nodes) == 1:
return

parent_nodes = list(parent_map.keys())
queue = [parent_nodes[0]] if parent_map else []
connected = {table_name: False for table_name in nodes}

while queue:
node = queue.pop()
connected[node] = True
for child in list(child_map[node]) + list(parent_map[node]):
if not connected[child] and child not in queue:
queue.append(child)

if not all(connected.values()):
disconnected_tables = [table for table, value in connected.items() if not value]
if len(disconnected_tables) > 1:
table_msg = (
f'Tables {disconnected_tables} are not connected to any of the other tables.'
)
else:
table_msg = (
f'Table {disconnected_tables} is not connected to any of the other tables.'
)

raise InvalidMetadataError(
f'The relationships in the dataset are disjointed. {table_msg}')
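
To make the traversal concrete, here is a worked sketch with three hypothetical tables. The maps follow the shapes built by _get_parent_map and _get_child_map (child -> parents and parent -> children, respectively):

    from collections import defaultdict

    # users -> sessions -> transactions
    parent_map = defaultdict(set, {'sessions': {'users'}, 'transactions': {'sessions'}})
    child_map = defaultdict(set, {'users': {'sessions'}, 'sessions': {'transactions'}})

    # The queue starts from a table that has a parent (e.g. 'sessions') and walks
    # both child and parent edges, so all three tables end up marked connected.
    # An isolated 'logs' table would stay False in `connected` and trigger
    # InvalidMetadataError('The relationships in the dataset are disjointed. ...').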

def _detect_relationships(self):
"""Automatically detect relationships between tables."""
for parent_candidate in self.tables.keys():
primary_key = self.tables[parent_candidate].primary_key
for child_candidate in self.tables.keys() - {parent_candidate}:
child_meta = self.tables[child_candidate]
if primary_key in child_meta.columns.keys():
try:
original_foreign_key_sdtype = child_meta.columns[primary_key]['sdtype']
if original_foreign_key_sdtype != 'id':
self.update_column(child_candidate, primary_key, sdtype='id')

self.add_relationship(
parent_candidate,
child_candidate,
primary_key,
primary_key
)
except InvalidMetadataError:
self.update_column(child_candidate,
primary_key,
sdtype=original_foreign_key_sdtype)
continue

try:
self._validate_all_tables_connected(self._get_parent_map(), self._get_child_map())
except InvalidMetadataError as invalid_error:
warning_msg = (
f'Could not automatically add relationships for all tables. {str(invalid_error)}'
)
warnings.warn(warning_msg)

def detect_table_from_dataframe(self, table_name, data):
"""Detect the metadata for a table from a dataframe.

@@ -361,6 +469,8 @@ def detect_from_dataframes(self, data):
for table_name, dataframe in data.items():
self.detect_table_from_dataframe(table_name, dataframe)

self._detect_relationships()
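
A usage sketch with hypothetical dataframes: after the per-table detection loop, _detect_relationships links any child column that shares the name of another table's primary key:

    metadata = MultiTableMetadata()
    metadata.detect_from_dataframes({'users': users_df, 'sessions': sessions_df})
    # If 'users' has primary key 'user_id' and 'sessions' also has a 'user_id'
    # column, a users -> sessions relationship is added automatically. If some
    # table remains disconnected, a warning is emitted instead of an error.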

def detect_table_from_csv(self, table_name, filepath):
"""Detect the metadata for a table from a csv file.

@@ -399,6 +509,8 @@ def detect_from_csvs(self, folder_name):
table_name = csv_file.stem
self.detect_table_from_csv(table_name, str(csv_file))

self._detect_relationships()

def set_primary_key(self, table_name, column_name):
"""Set the primary key of a table.

@@ -465,32 +577,6 @@ def _validate_single_table(self, errors):
'The following errors were found in the metadata:\n', title)
errors.append(error)

def _append_relationships_errors(self, errors, method, *args, **kwargs):
try:
method(*args, **kwargs)
@@ -500,15 +586,6 @@ def _append_relationships_errors(self, errors, method, *args, **kwargs):

errors.append(error)

def validate(self):
"""Validate the metadata.

109 changes: 100 additions & 9 deletions sdv/metadata/single_table.py
@@ -17,8 +17,8 @@
from sdv.metadata.visualization import (
create_columns_node, create_summarized_columns_node, visualize_graph)
from sdv.utils import (
cast_to_iterable, format_invalid_values_string, get_datetime_format, is_boolean_type,
is_datetime_type, is_numerical_type, load_data_from_csv, validate_datetime_format)

LOGGER = logging.getLogger(__name__)

@@ -32,13 +32,11 @@ class SingleTableMetadata:
'categorical': frozenset(['order', 'order_by']),
'boolean': frozenset([]),
'id': frozenset(['regex_format']),
'unknown': frozenset(['pii']),
}

_DTYPES_TO_SDTYPES = {
'b': 'categorical',
'M': 'datetime',
}

@@ -252,11 +250,104 @@ def to_dict(self):

return deepcopy(metadata)

def _determine_sdtype_for_numbers(self, data):
"""Determine the sdtype for a numerical column.

Args:
data (pandas.Series):
The data to be analyzed.
"""
sdtype = 'numerical'
if len(data) > 5:
is_not_null = ~data.isna()
whole_values = (data == data.round()).loc[is_not_null].all()
positive_values = (data >= 0).loc[is_not_null].all()

unique_values = data.nunique()
unique_lt_categorical_threshold = unique_values <= min(round(len(data) / 10), 10)

if whole_values and positive_values and unique_lt_categorical_threshold:
sdtype = 'categorical'
elif unique_values == len(data) and whole_values:
sdtype = 'id'

return sdtype
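
A worked illustration of the thresholds above, on hypothetical series (columns with 5 or fewer rows always stay 'numerical'):

    import pandas as pd

    # 100 rows, 3 distinct non-negative whole values:
    # 3 <= min(round(100 / 10), 10) -> 'categorical'
    pd.Series([1, 2, 3] * 33 + [1])

    # 100 whole values, all distinct -> unique_values == len(data) -> 'id'
    pd.Series(range(100))

    # Decimal parts fail the whole-number check -> 'numerical'
    pd.Series([1.5, 2.25, 3.75] * 34)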

def _determine_sdtype_for_objects(self, data):
"""Determine the sdtype for an object column.

Args:
data (pandas.Series):
The data to be analyzed.
"""
sdtype = None
try:
with warnings.catch_warnings():
warnings.simplefilter('ignore', category=UserWarning)

data_test = data.sample(10000) if len(data) > 10000 else data
datetime_format = get_datetime_format(data_test)
pd.to_datetime(data_test, format=datetime_format, errors='raise')

sdtype = 'datetime'

except Exception:
if len(data) <= 5:
sdtype = 'categorical'
else:
unique_values = data.nunique()
if unique_values == len(data):
sdtype = 'id'
elif unique_values <= round(len(data) / 5):
sdtype = 'categorical'
else:
sdtype = 'unknown'

return sdtype
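
And a sketch of the object-column heuristics on hypothetical data:

    import pandas as pd

    # Parses with a guessed datetime format -> 'datetime'
    pd.Series(['2023-09-27', '2023-09-28', '2023-09-29'])

    # More than 5 rows, every value distinct -> 'id'
    pd.Series([f'user_{i}' for i in range(10)])

    # Few repeated values: nunique <= round(len / 5) -> 'categorical'
    pd.Series(['red', 'blue'] * 10)

    # Repeats, but too many distinct values for 'categorical' -> 'unknown'
    pd.Series(['a1', 'b2', 'c3', 'd4', 'e5'] * 2)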

def _detect_columns(self, data):
"""Detect the columns' sdtype and the primary key from the data.

Args:
data (pandas.DataFrame):
The data to be analyzed.
"""
for field in data:
column_data = data[field]
clean_data = column_data.dropna()
dtype = clean_data.infer_objects().dtype.kind

sdtype = None
if dtype in self._DTYPES_TO_SDTYPES:
sdtype = self._DTYPES_TO_SDTYPES[dtype]
elif dtype in ['i', 'f']:
sdtype = self._determine_sdtype_for_numbers(column_data)

elif dtype == 'O':
sdtype = self._determine_sdtype_for_objects(column_data)

if sdtype is None:
raise InvalidMetadataError(
f"Unsupported data type for column '{field}' (kind: {dtype})."
"The valid data types are: 'object', 'int', 'float', 'datetime', 'bool'."
)

# Set the first ID column we detect to be the primary key
if sdtype == 'id':
if self.primary_key is None:
self.primary_key = field
else:
sdtype = 'unknown'

column_dict = {'sdtype': sdtype}

if sdtype == 'unknown':
column_dict['pii'] = True
elif sdtype == 'datetime' and dtype == 'O':
datetime_format = get_datetime_format(column_data.iloc[:100])
column_dict['datetime_format'] = datetime_format

self.columns[field] = deepcopy(column_dict)
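
Putting the rules together, a hypothetical dataframe would be detected roughly like this (the first id column becomes the primary key; later id-like columns are downgraded to 'unknown' and flagged as pii):

    # columns: 'guest_id' (distinct ints), 'email' (distinct strings),
    # 'amount' (floats), 'checkin' ('2023-09-27'-style strings)
    #
    # resulting metadata:
    # primary_key == 'guest_id'
    # columns == {
    #     'guest_id': {'sdtype': 'id'},
    #     'email': {'sdtype': 'unknown', 'pii': True},
    #     'amount': {'sdtype': 'numerical'},
    #     'checkin': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'},
    # }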

def detect_from_dataframe(self, data):
"""Detect the metadata from a ``pd.DataFrame`` object.
1 change: 1 addition & 0 deletions sdv/utils.py
@@ -35,6 +35,7 @@ def get_datetime_format(value):

value = value[~value.isna()]
value = value.astype(str).to_numpy()

return _guess_datetime_format_for_array(value)
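
A quick behavioral sketch, assuming a pandas.Series input as in the surrounding code (the heavy lifting is delegated to pandas' format guessing):

    import pandas as pd
    from sdv.utils import get_datetime_format

    get_datetime_format(pd.Series(['2023-09-27', '2023-09-28']))  # -> '%Y-%m-%d'
    get_datetime_format(pd.Series(['not a date']))                # -> None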

