Skip to content

Commit

Permalink
Fix: PARSynthesizer not being able to conditionally sample with date time as context (#2347)
Browse files Browse the repository at this point in the history
  • Loading branch information
pvk-developer authored Jan 22, 2025
1 parent 0ddaf6a commit ca7892c
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 14 deletions.
69 changes: 57 additions & 12 deletions sdv/sequential/par.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import uuid
import warnings
from copy import deepcopy

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -74,15 +75,44 @@ def _get_context_metadata(self):
if self._sequence_key:
context_columns += self._sequence_key

for column in context_columns:
context_columns_dict[column] = self.metadata.columns[column]

for column, column_metadata in self._extra_context_columns.items():
context_columns_dict[column] = column_metadata

for column in context_columns:
context_columns_dict[column] = self.metadata.columns[column]

context_columns_dict = self._update_context_column_dict(context_columns_dict)
context_metadata_dict = {'columns': context_columns_dict}
return SingleTableMetadata.load_from_dict(context_metadata_dict)

def _update_context_column_dict(self, context_columns_dict):
"""Update context column dictionary based on available transformers.
Args:
context_columns_dict (dict):
Dictionary of context columns.
Returns:
dict:
Updated context column metadata.
"""
default_transformers_by_sdtype = deepcopy(self._data_processor._transformers_by_sdtype)
for column in self.context_columns:
column_metadata = self.metadata.columns[column]
if default_transformers_by_sdtype.get(column_metadata['sdtype']):
context_columns_dict[column] = {'sdtype': 'numerical'}

return context_columns_dict

def _get_context_columns_for_processing(self):
columns_to_be_processed = []
default_transformers_by_sdtype = deepcopy(self._data_processor._transformers_by_sdtype)
for column in self.context_columns:
if default_transformers_by_sdtype.get(self.metadata.columns[column]['sdtype']):
columns_to_be_processed.append(column)

return columns_to_be_processed

def __init__(
self,
metadata,
Expand Down Expand Up @@ -352,12 +382,6 @@ def _fit_context_model(self, transformed):
context[constant_column] = 0
context_metadata.add_column(constant_column, sdtype='numerical')

for column in self.context_columns:
# Context datetime SDTypes for PAR have already been converted to float timestamp
if context_metadata.columns[column]['sdtype'] == 'datetime':
if pd.api.types.is_numeric_dtype(context[column]):
context_metadata.update_column(column, sdtype='numerical')

with warnings.catch_warnings():
warnings.filterwarnings('ignore', message=".*The 'SingleTableMetadata' is deprecated.*")
self._context_synthesizer = GaussianCopulaSynthesizer(
Expand Down Expand Up @@ -540,9 +564,30 @@ def sample_sequential_columns(self, context_columns, sequence_length=None):
set(context_columns.columns), set(self._context_synthesizer._model.columns)
)
)
context_columns = self._process_context_columns(context_columns)

condition_columns = context_columns[condition_columns].to_dict('records')
context = self._context_synthesizer.sample_from_conditions([
Condition(conditions) for conditions in condition_columns
])
synthesizer_conditions = [Condition(conditions) for conditions in condition_columns]
context = self._context_synthesizer.sample_from_conditions(synthesizer_conditions)
context.update(context_columns)
return self._sample(context, sequence_length)

def _process_context_columns(self, context_columns):
"""Process context columns by applying appropriate transformations.
Args:
context_columns (pandas.DataFrame):
Context values containing potential columns for transformation.
Returns:
context_columns (pandas.DataFrame):
Updated context columns with transformed values.
"""
columns_to_be_processed = self._get_context_columns_for_processing()

if columns_to_be_processed:
context_columns[columns_to_be_processed] = self._data_processor.transform(
context_columns[columns_to_be_processed]
)

return context_columns
9 changes: 8 additions & 1 deletion tests/integration/sequential/test_par.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def _get_par_data_and_metadata():
'column2': ['b', 'a', 'a', 'c'],
'entity': [1, 1, 2, 2],
'context': ['a', 'a', 'b', 'b'],
'context_date': [date, date, date, date],
})
metadata = Metadata.detect_from_dataframes({'table': data})
metadata.update_column('entity', 'table', sdtype='id')
Expand Down Expand Up @@ -94,15 +95,21 @@ def test_column_after_date_complex():
data, metadata = _get_par_data_and_metadata()

# Run
model = PARSynthesizer(metadata=metadata, context_columns=['context'], epochs=1)
model = PARSynthesizer(metadata=metadata, context_columns=['context', 'context_date'], epochs=1)
model.fit(data)
sampled = model.sample(2)
context_columns = data[['context', 'context_date']]
sample_with_conditions = model.sample_sequential_columns(context_columns=context_columns)

# Assert
assert sampled.shape == data.shape
assert (sampled.dtypes == data.dtypes).all()
assert (sampled.notna().sum(axis=1) != 0).all()

expected_date = datetime.datetime.strptime('2020-01-01', '%Y-%m-%d')
assert all(sample_with_conditions['context_date'] == expected_date)
assert all(sample_with_conditions['context'].isin(['a', 'b']))


def test_save_and_load(tmp_path):
"""Test that synthesizers can be saved and loaded properly."""
Expand Down
69 changes: 68 additions & 1 deletion tests/unit/sequential/test_par.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def test__fit_context_model_with_datetime_context_column(self, gaussian_copula_m
par = PARSynthesizer(metadata, context_columns=['time'])
initial_synthesizer = Mock()
context_metadata = SingleTableMetadata.load_from_dict({
'columns': {'time': {'sdtype': 'datetime'}, 'name': {'sdtype': 'id'}}
'columns': {'time': {'sdtype': 'numerical'}, 'name': {'sdtype': 'id'}}
})
par._context_synthesizer = initial_synthesizer
par._get_context_metadata = Mock()
Expand Down Expand Up @@ -934,6 +934,8 @@ def test_sample_sequential_columns(self):
"""Test that the method uses the provided context columns to sample."""
# Setup
par = PARSynthesizer(metadata=self.get_metadata(), context_columns=['gender'])
par._process_context_columns = Mock()
par._process_context_columns.side_effect = lambda value: value
par._context_synthesizer = Mock()
par._context_synthesizer._model.columns = ['gender', 'extra_col']
par._context_synthesizer.sample_from_conditions.return_value = pd.DataFrame({
Expand Down Expand Up @@ -970,6 +972,7 @@ def test_sample_sequential_columns(self):
call_args, _ = par._sample.call_args
pd.testing.assert_frame_equal(call_args[0], expected_call_arg)
assert call_args[1] == 5
par._process_context_columns.assert_called_once_with(context_columns)

def test_sample_sequential_columns_no_context_columns(self):
"""Test that the method raises an error if the synthesizer has no context columns.
Expand Down Expand Up @@ -1083,3 +1086,67 @@ def test___init__with_unified_metadata(self):

with pytest.raises(InvalidMetadataError, match=error_msg):
PARSynthesizer(multi_metadata)

def test_sample_sequential_columns_with_datetime_values(self):
    """Test that the method converts datetime values to numerical space before sampling."""
    # Setup
    synthesizer = PARSynthesizer(metadata=self.get_metadata(), context_columns=['time'])
    synthesizer.fit(self.get_data())

    synthesizer._context_synthesizer = Mock()
    synthesizer._context_synthesizer._model.columns = ['time', 'extra_col']
    synthesizer._context_synthesizer.sample_from_conditions.return_value = pd.DataFrame({
        'id_col': ['A', 'A', 'A'],
        'time': ['2020-01-01', '2020-01-02', '2020-01-03'],
        'extra_col': [0, 1, 1],
    })
    synthesizer._sample = Mock()
    context_columns = pd.DataFrame({
        'id_col': ['ID-1', 'ID-2', 'ID-3'],
        'time': ['2020-01-01', '2020-01-02', '2020-01-03'],
    })

    # Run
    synthesizer.sample_sequential_columns(context_columns, 5)

    # Assert
    transformed = synthesizer._data_processor.transform(
        pd.DataFrame({'time': ['2020-01-01', '2020-01-02', '2020-01-03']})
    )
    expected_conditions = [
        Condition({'time': value}) for value in transformed['time'].tolist()
    ]
    call_args, _ = synthesizer._context_synthesizer.sample_from_conditions.call_args

    assert len(call_args[0]) == len(expected_conditions)
    for actual, expected in zip(call_args[0], expected_conditions):
        assert actual.column_values == expected.column_values
        assert actual.num_rows == expected.num_rows

def test__process_context_columns(self):
    """Test that the method processes specified columns using appropriate transformations."""
    # Setup
    mock_instance = Mock()
    mock_instance._get_context_columns_for_processing.return_value = ['datetime_col']
    mock_instance._data_processor.transform.return_value = pd.DataFrame({
        'datetime_col': [1, 2, 3],
    })

    input_frame = pd.DataFrame({
        'datetime_col': ['2021-01-01', '2022-01-01', '2023-01-01'],
        'col2': [4, 5, 6],
    })

    # Run
    result = PARSynthesizer._process_context_columns(mock_instance, input_frame)

    # Assert
    pd.testing.assert_frame_equal(
        result,
        pd.DataFrame({
            'datetime_col': [1, 2, 3],
            'col2': [4, 5, 6],
        }),
    )

0 comments on commit ca7892c

Please sign in to comment.