Skip to content

Commit

Permalink
Make functions more generalized
Browse files Browse the repository at this point in the history
  • Loading branch information
pvk-developer committed Jan 22, 2025
1 parent b35f314 commit 1127172
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 30 deletions.
62 changes: 42 additions & 20 deletions sdv/sequential/par.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import uuid
import warnings
from copy import deepcopy

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -74,25 +75,44 @@ def _get_context_metadata(self):
if self._sequence_key:
context_columns += self._sequence_key

for column in context_columns:
context_columns_dict[column] = self.metadata.columns[column]
# Context datetime SDTypes for PAR have already been converted to float timestamp
if context_columns_dict[column]['sdtype'] == 'datetime':
context_columns_dict[column] = {'sdtype': 'numerical'}

for column, column_metadata in self._extra_context_columns.items():
context_columns_dict[column] = column_metadata

for column in context_columns:
context_columns_dict[column] = self.metadata.columns[column]

context_columns_dict = self._update_context_column_dict(context_columns_dict)
context_metadata_dict = {'columns': context_columns_dict}
return SingleTableMetadata.load_from_dict(context_metadata_dict)

def _get_context_datetime_columns(self):
datetime_columns = []
def _update_context_column_dict(self, context_columns_dict):
"""Update context column dictionary based on available transformers.
Args:
context_columns_dict (dict):
Dictionary of context columns.
Returns:
dict:
Updated context column metadata.
"""
default_transformers_by_sdtype = deepcopy(self._data_processor._transformers_by_sdtype)
for column in self.context_columns:
if self.metadata.columns[column]['sdtype'] == 'datetime':
datetime_columns.append(column)
column_metadata = self.metadata.columns[column]
sdtype = column_metadata['sdtype']
if default_transformers_by_sdtype.get(column_metadata['sdtype']):
context_columns_dict[column] = {'sdtype': 'numerical'}

return datetime_columns
return context_columns_dict

def _get_context_columns_for_processing(self):
    """Return the context columns that must go through the data processor.

    A context column is selected when the sdtype declared for it in
    ``self.metadata.columns`` maps to a truthy entry in
    ``self._data_processor._transformers_by_sdtype``. Column order follows
    ``self.context_columns``.

    Returns:
        list:
            Names of the context columns to be transformed. Empty when no
            context column has a transformer for its sdtype.
    """
    # The mapping is only read here, so it is used directly instead of being
    # deep-copied on every call (copying every transformer is wasted work and
    # fails for transformers that cannot be deep-copied).
    transformers_by_sdtype = self._data_processor._transformers_by_sdtype
    return [
        column
        for column in self.context_columns
        if transformers_by_sdtype.get(self.metadata.columns[column]['sdtype'])
    ]

def __init__(
self,
Expand Down Expand Up @@ -545,28 +565,30 @@ def sample_sequential_columns(self, context_columns, sequence_length=None):
set(context_columns.columns), set(self._context_synthesizer._model.columns)
)
)
context_columns = self._process_datetime_columns_in_context_columns(context_columns)
context_columns = self._process_context_columns(context_columns)

condition_columns = context_columns[condition_columns].to_dict('records')
synthesizer_conditions = [Condition(conditions) for conditions in condition_columns]
context = self._context_synthesizer.sample_from_conditions(synthesizer_conditions)
context.update(context_columns)
return self._sample(context, sequence_length)

def _process_datetime_columns_in_context_columns(self, context_columns):
"""Process datetime columns by transforming them using the data processor.
def _process_context_columns(self, context_columns):
"""Process context columns by applying appropriate transformations.
Args:
context_columns (pandas.DataFrame):
Context values containing potential datetime columns.
Context values containing potential columns for transformation.
Returns:
context_columns (pandas.DataFrame):
Updated context columns with transformed datetime values.
Updated context columns with transformed values.
"""
datetime_columns = self._get_context_datetime_columns()
if datetime_columns:
transformed = self._data_processor.transform(context_columns[datetime_columns])
context_columns[datetime_columns] = transformed[datetime_columns]
columns_to_be_processed = self._get_context_columns_for_processing()

if columns_to_be_processed:
context_columns[columns_to_be_processed] = self._data_processor.transform(
context_columns[columns_to_be_processed]
)

return context_columns
22 changes: 12 additions & 10 deletions tests/unit/sequential/test_par.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,7 +934,8 @@ def test_sample_sequential_columns(self):
"""Test that the method uses the provided context columns to sample."""
# Setup
par = PARSynthesizer(metadata=self.get_metadata(), context_columns=['gender'])
par._get_context_datetime_columns = Mock(return_value=None)
par._process_context_columns = Mock()
par._process_context_columns.side_effect = lambda value: value
par._context_synthesizer = Mock()
par._context_synthesizer._model.columns = ['gender', 'extra_col']
par._context_synthesizer.sample_from_conditions.return_value = pd.DataFrame({
Expand Down Expand Up @@ -971,7 +972,7 @@ def test_sample_sequential_columns(self):
call_args, _ = par._sample.call_args
pd.testing.assert_frame_equal(call_args[0], expected_call_arg)
assert call_args[1] == 5
par._get_context_datetime_columns.assert_called_once_with()
par._process_context_columns.assert_called_once_with(context_columns)

def test_sample_sequential_columns_no_context_columns(self):
"""Test that the method raises an error if the synthesizer has no context columns.
Expand Down Expand Up @@ -1127,24 +1128,25 @@ def test_sample_sequential_columns_with_datetime_values(self):
assert arg.column_values == expected.column_values
assert arg.num_rows == expected.num_rows

def test__process_datetime_columns_in_context_columns(self):
"""Test that the method converts datetime columns into numerical space."""
def test__process_context_columns(self):
"""Test that the method processes specified columns using appropriate transformations."""
# Setup
instance = Mock()
instance._get_context_datetime_columns.return_value = ['Date']
instance._get_context_columns_for_processing.return_value = ['datetime_col']
instance._data_processor.transform.return_value = pd.DataFrame({'datetime_col': [1, 2, 3]})
instance._get_context_datetime_columns.return_value = ['datetime_col']

context_columns = pd.DataFrame({
'datetime_col': ['2021-01-01', '2022-01-01', '2023-01-01'],
'col2': [4, 5, 6],
})

expected_result = pd.DataFrame({
'datetime_col': [1, 2, 3],
'col2': [4, 5, 6],
})

# Run
result = PARSynthesizer._process_datetime_columns_in_context_columns(
instance, context_columns
)
result = PARSynthesizer._process_context_columns(instance, context_columns)

# Assert
expected_result = pd.DataFrame({'datetime_col': [1, 2, 3], 'col2': [4, 5, 6]})
pd.testing.assert_frame_equal(result, expected_result)

0 comments on commit 1127172

Please sign in to comment.