Skip to content

Commit

Permalink
Multi-parent re-model and re-sample issue (#298)
Browse files Browse the repository at this point in the history
* Avoid modeling and sampling children tables multiple times

* Ensure multi-foreign key schemas are supported

* Compare dtypes by kind
  • Loading branch information
csala authored Jan 20, 2021
1 parent 8f41bc8 commit 313c8b8
Show file tree
Hide file tree
Showing 8 changed files with 182 additions and 93 deletions.
24 changes: 11 additions & 13 deletions sdv/metadata/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,10 +299,10 @@ def get_primary_key(self, table_name):
"""
return self.get_table_meta(table_name).get('primary_key')

def get_foreign_key(self, parent, child):
"""Get the name of the field in the child that is a foreign key to parent.
def get_foreign_keys(self, parent, child):
"""Get the name of all the fields in the child that are foreign keys to this parent.
If there is no relationship between the two tables, a ``ValueError`` is raised.
If there is no relationship between the two tables an empty list is returned.
Args:
parent (str):
Expand All @@ -311,19 +311,16 @@ def get_foreign_key(self, parent, child):
Name of the child table.
Returns:
str or None:
Foreign key field name.
Raises:
ValueError:
If the relationship does not exist.
list[str]:
List of foreign key names.
"""
foreign_keys = []
for name, field in self.get_fields(child).items():
ref = field.get('ref')
if ref and ref['table'] == parent:
return name
foreign_keys.append(name)

raise ValueError('{} is not parent of {}'.format(parent, child))
return foreign_keys

def load_table(self, table_name):
"""Load the data of the indicated table as a DataFrame.
Expand Down Expand Up @@ -396,7 +393,7 @@ def get_dtypes(self, table_name, ids=False):
if ids and field_type == 'id':
if (name != table_meta.get('primary_key')) and not field.get('ref'):
for child_table in self.get_children(table_name):
if name == self.get_foreign_key(table_name, child_table):
if name in self.get_foreign_keys(table_name, child_table):
break

if ids or (field_type != 'id'):
Expand Down Expand Up @@ -967,11 +964,12 @@ def __repr__(self):
tables = self.get_tables()
relationships = [
' {}.{} -> {}.{}'.format(
table, self.get_foreign_key(parent, table),
table, foreign_key,
parent, self.get_primary_key(parent)
)
for table in tables
for parent in list(self.get_parents(table))
for foreign_key in self.get_foreign_keys(parent, table)
]

return (
Expand Down
16 changes: 10 additions & 6 deletions sdv/metadata/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def _add_nodes(metadata, digraph):

parents = metadata.get_parents(table)
for parent in parents:
foreign_key = metadata.get_foreign_key(parent, table)
extras.append('Foreign key ({}): {}'.format(parent, foreign_key))
for foreign_key in metadata.get_foreign_keys(parent, table):
extras.append('Foreign key ({}): {}'.format(parent, foreign_key))

path = metadata.get_table_meta(table).get('path')
if path is not None:
Expand All @@ -80,13 +80,17 @@ def _add_edges(metadata, digraph):
"""
for table in metadata.get_tables():
for parent in list(metadata.get_parents(table)):
label = '\n'.join([
' {}.{} -> {}.{}'.format(
table, foreign_key,
parent, metadata.get_primary_key(parent)
)
for foreign_key in metadata.get_foreign_keys(parent, table)
])
digraph.edge(
parent,
table,
label=' {}.{} -> {}.{}'.format(
table, metadata.get_foreign_key(parent, table),
parent, metadata.get_primary_key(parent)
),
label=label,
arrowhead='oinv'
)

Expand Down
137 changes: 81 additions & 56 deletions sdv/relational/hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,19 +90,33 @@ def _get_extension(self, child_name, child_table, foreign_key):
model.fit(child_rows.reset_index(drop=True))
row = model.get_parameters()
row = pd.Series(row)
row.index = '__' + child_name + '__' + row.index
row.index = f'__{child_name}__{foreign_key}__' + row.index
extension_rows.append(row)

return pd.DataFrame(extension_rows, index=foreign_key_values)

def _load_table(self, tables, table_name):
if tables:
table = tables[table_name].copy()
else:
table = self.metadata.load_table(table_name)
tables[table_name] = table

return table

def _extend_table(self, table, tables, table_name):
LOGGER.info('Computing extensions for table %s', table_name)
for child_name in self.metadata.get_children(table_name):
child_key = self.metadata.get_foreign_key(table_name, child_name)
child_table = self._model_table(child_name, tables, child_key)
extension = self._get_extension(child_name, child_table, child_key)
table = table.merge(extension, how='left', right_index=True, left_index=True)
table['__' + child_name + '__num_rows'].fillna(0, inplace=True)
if child_name not in self._models:
child_table = self._model_table(child_name, tables)
else:
child_table = tables[child_name]

foreign_keys = self.metadata.get_foreign_keys(table_name, child_name)
for index, foreign_key in enumerate(foreign_keys):
extension = self._get_extension(child_name, child_table, foreign_key)
table = table.merge(extension, how='left', right_index=True, left_index=True)
table[f'__{child_name}__{foreign_key}__num_rows'].fillna(0, inplace=True)

return table

Expand All @@ -116,6 +130,12 @@ def _prepare_for_modeling(self, table_data, table_name, primary_key):
table_meta['primary_key'] = None
del table_meta['fields'][primary_key]

keys = {}
for name, field in list(fields.items()):
if field['type'] == 'id':
keys[name] = table_data.pop(name).values
del fields[name]

for column in table_data.columns:
if column not in fields:
fields[column] = {
Expand All @@ -131,43 +151,32 @@ def _prepare_for_modeling(self, table_data, table_name, primary_key):

table_data[column] = table_data[column].fillna(fill_value)

return table_meta
return table_meta, keys

def _model_table(self, table_name, tables, foreign_key=None):
def _model_table(self, table_name, tables):
"""Model the indicated table and its children.
Args:
table_name (str):
Name of the table to model.
tables (dict):
Dict of original tables.
foreign_key (str):
Name of the foreign key that references this table. Used only when modeling
a child table.
Returns:
pandas.DataFrame:
table data with the extensions created while modeling its children.
"""
LOGGER.info('Modeling %s', table_name)

if tables:
table = tables[table_name].copy()
else:
table = self.metadata.load_table(table_name)

table = self._load_table(tables, table_name)
self._table_sizes[table_name] = len(table)

primary_key = self.metadata.get_primary_key(table_name)
if primary_key:
table = table.set_index(primary_key)
table = self._extend_table(table, tables, table_name)

table_meta = self._prepare_for_modeling(table, table_name, primary_key)

if foreign_key:
foreign_key_values = table.pop(foreign_key).values
del table_meta['fields'][foreign_key]
table_meta, keys = self._prepare_for_modeling(table, table_name, primary_key)

LOGGER.info('Fitting %s for table %s; shape: %s', self._model.__name__,
table_name, table.shape)
Expand All @@ -178,8 +187,10 @@ def _model_table(self, table_name, tables, foreign_key=None):
if primary_key:
table.reset_index(inplace=True)

if foreign_key:
table[foreign_key] = foreign_key_values
for name, values in keys.items():
table[name] = values

tables[table_name] = table

return table

Expand All @@ -193,6 +204,10 @@ def _fit(self, tables=None):
indicated in ``metadata``. Defaults to ``None``.
"""
self.metadata.validate(tables)
if tables:
tables = tables.copy()
else:
tables = {}

for table_name in self.metadata.get_tables():
if not self.metadata.get_parents(table_name):
Expand Down Expand Up @@ -224,27 +239,34 @@ def _finalize(self, sampled_data):
parents = self.metadata.get_parents(table_name)
if parents:
for parent_name in parents:
foreign_key = self.metadata.get_foreign_key(parent_name, table_name)
if foreign_key not in table_rows:
parent_ids = self._find_parent_ids(table_name, parent_name, sampled_data)
table_rows[foreign_key] = parent_ids
foreign_keys = self.metadata.get_foreign_keys(parent_name, table_name)
for foreign_key in foreign_keys:
if foreign_key not in table_rows:
parent_ids = self._find_parent_ids(
table_name, parent_name, foreign_key, sampled_data)
table_rows[foreign_key] = parent_ids

fields = self.metadata.get_fields(table_name)
dtypes = self.metadata.get_dtypes(table_name, ids=True)
for name, dtype in dtypes.items():
table_rows[name] = table_rows[name].dropna().astype(dtype)

final_data[table_name] = table_rows[list(fields.keys())]
final_data[table_name] = table_rows[list(dtypes.keys())]

return final_data

def _extract_parameters(self, parent_row, table_name):
def _extract_parameters(self, parent_row, table_name, foreign_key):
"""Get the params from a generated parent row.
Args:
parent_row (pandas.Series):
A generated parent row.
table_name (str):
Name of the table to make the model for.
foreign_key (str):
Name of the foreign key used to form this
parent child relationship.
"""
prefix = '__{}__'.format(table_name)
prefix = f'__{table_name}__{foreign_key}__'
keys = [key for key in parent_row.keys() if key.startswith(prefix)]
new_keys = {key: key[len(prefix):] for key in keys}
flat_parameters = parent_row[keys]
Expand Down Expand Up @@ -274,15 +296,15 @@ def _sample_rows(self, model, table_name, num_rows=None):
return sampled

def _sample_children(self, table_name, sampled_data, table_rows=None):
if table_rows is None:
table_rows = sampled_data[table_name]

for child_name in self.metadata.get_children(table_name):
for _, row in table_rows.iterrows():
self._sample_child_rows(child_name, table_name, row, sampled_data)
if child_name not in sampled_data:
LOGGER.info('Sampling rows from child table %s', child_name)
for _, row in table_rows.iterrows():
self._sample_child_rows(child_name, table_name, row, sampled_data)

def _sample_child_rows(self, table_name, parent_name, parent_row, sampled_data):
parameters = self._extract_parameters(parent_row, table_name)
foreign_key = self.metadata.get_foreign_keys(parent_name, table_name)[0]
parameters = self._extract_parameters(parent_row, table_name, foreign_key)

table_meta = self._models[table_name].get_metadata()
model = self._model(table_metadata=table_meta)
Expand All @@ -291,7 +313,6 @@ def _sample_child_rows(self, table_name, parent_name, parent_row, sampled_data):
table_rows = self._sample_rows(model, table_name)
if not table_rows.empty:
parent_key = self.metadata.get_primary_key(parent_name)
foreign_key = self.metadata.get_foreign_key(parent_name, table_name)
table_rows[foreign_key] = parent_row[parent_key]

previous = sampled_data.get(table_name)
Expand Down Expand Up @@ -323,10 +344,10 @@ def _find_parent_id(likelihoods, num_rows):

return np.random.choice(likelihoods.index, p=weights)

def _get_likelihoods(self, table_rows, parent_rows, table_name):
def _get_likelihoods(self, table_rows, parent_rows, table_name, foreign_key):
likelihoods = dict()
for parent_id, row in parent_rows.iterrows():
parameters = self._extract_parameters(row, table_name)
parameters = self._extract_parameters(row, table_name, foreign_key)
table_meta = self._models[table_name].get_metadata()
model = self._model(table_metadata=table_meta)
model.set_parameters(parameters)
Expand All @@ -337,7 +358,7 @@ def _get_likelihoods(self, table_rows, parent_rows, table_name):

return pd.DataFrame(likelihoods, index=table_rows.index)

def _find_parent_ids(self, table_name, parent_name, sampled_data):
def _find_parent_ids(self, table_name, parent_name, foreign_key, sampled_data):
table_rows = sampled_data[table_name]
if parent_name in sampled_data:
parent_rows = sampled_data[parent_name]
Expand All @@ -349,29 +370,29 @@ def _find_parent_ids(self, table_name, parent_name, sampled_data):

primary_key = self.metadata.get_primary_key(parent_name)
parent_rows = parent_rows.set_index(primary_key)
num_rows = parent_rows['__' + table_name + '__num_rows'].fillna(0).clip(0)
num_rows = parent_rows[f'__{table_name}__{foreign_key}__num_rows'].fillna(0).clip(0)

likelihoods = self._get_likelihoods(table_rows, parent_rows, table_name)
likelihoods = self._get_likelihoods(table_rows, parent_rows, table_name, foreign_key)
return likelihoods.apply(self._find_parent_id, axis=1, num_rows=num_rows)

def _sample_table(self, table_name, num_rows=None, sample_children=True):
def _sample_table(self, table_name, num_rows=None, sample_children=True, sampled_data=None):
"""Sample a single table and optionally its children."""
if sampled_data is None:
sampled_data = {}

if num_rows is None:
num_rows = self._table_sizes[table_name]

LOGGER.info('Sampling %s rows from table %s', num_rows, table_name)

model = self._models[table_name]
table_rows = self._sample_rows(model, table_name, num_rows)
sampled_data[table_name] = table_rows

if sample_children:
sampled_data = {
table_name: table_rows
}

self._sample_children(table_name, sampled_data)
return self._finalize(sampled_data)
self._sample_children(table_name, sampled_data, table_rows)

else:
return self._finalize({table_name: table_rows})[table_name]
return sampled_data

def _sample(self, table_name=None, num_rows=None, sample_children=True):
"""Sample the entire dataset.
Expand Down Expand Up @@ -400,12 +421,16 @@ def _sample(self, table_name=None, num_rows=None, sample_children=True):
A ``NotFittedError`` is raised when the ``SDV`` instance has not been fitted yet.
"""
if table_name:
return self._sample_table(table_name, num_rows, sample_children)
sampled_data = self._sample_table(table_name, num_rows, sample_children)
sampled_data = self._finalize(sampled_data)
if sample_children:
return sampled_data

return sampled_data[table_name]

sampled_data = dict()
for table in self.metadata.get_tables():
if not self.metadata.get_parents(table):
sampled = self._sample_table(table, num_rows)
sampled_data.update(sampled)
self._sample_table(table, num_rows, sampled_data=sampled_data)

return sampled_data
return self._finalize(sampled_data)
Empty file added tests/__init__.py
Empty file.
22 changes: 22 additions & 0 deletions tests/integration/datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import pandas as pd

from sdv import Metadata


def load_multi_foreign_key():
    """Build a demo dataset in which the child table references its parent twice.

    The ``parent`` table has a ``parent_id`` primary key, and the ``child``
    table has two foreign key columns, ``parent_1_id`` and ``parent_2_id``,
    both registered as relationships to ``parent``.

    Returns:
        tuple:
            The ``Metadata`` instance describing both tables and a dict
            with the ``parent`` and ``child`` DataFrames.
    """
    ten = range(10)
    tables = {
        'parent': pd.DataFrame({'parent_id': ten, 'value': ten}),
        'child': pd.DataFrame({
            'parent_1_id': ten,
            'parent_2_id': ten,
            'value': ten,
        }),
    }

    metadata = Metadata()
    metadata.add_table('parent', tables['parent'], primary_key='parent_id')
    metadata.add_table('child', tables['child'], parent='parent', foreign_key='parent_1_id')
    # The second foreign key towards the same parent is added separately.
    metadata.add_relationship('parent', 'child', 'parent_2_id')

    return metadata, tables
Loading

0 comments on commit 313c8b8

Please sign in to comment.