Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve error message when trying to sample before fitting (single table) #1992

Merged
merged 3 commits into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion sdv/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
_groupby_list, check_sdv_versions_and_warn, check_synthesizer_version, generate_synthesizer_id)
from sdv.constraints.errors import AggregateConstraintsError
from sdv.data_processing.data_processor import DataProcessor
from sdv.errors import ConstraintsNotMetError, InvalidDataError, SynthesizerInputError
from sdv.errors import (
ConstraintsNotMetError, InvalidDataError, SamplingError, SynthesizerInputError)
from sdv.logging.utils import get_sdv_logger
from sdv.single_table.utils import check_num_rows, handle_sampling_error, validate_file_path

Expand Down Expand Up @@ -871,6 +872,12 @@ def sample(self, num_rows, max_tries_per_batch=100, batch_size=None, output_file
pandas.DataFrame:
Sampled data.
"""
if not self._fitted:
raise SamplingError(
'This synthesizer has not been fitted. Please fit your synthesizer first before'
' sampling synthetic data.'
)

sample_timestamp = datetime.datetime.now()
has_constraints = bool(self._data_processor._constraints)
has_batches = batch_size is not None and batch_size != num_rows
Expand Down
18 changes: 17 additions & 1 deletion tests/integration/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from sdv import version
from sdv.datasets.demo import download_demo
from sdv.errors import SynthesizerInputError, VersionError
from sdv.errors import SamplingError, SynthesizerInputError, VersionError
from sdv.metadata import SingleTableMetadata
from sdv.sampling import Condition
from sdv.single_table import (
Expand Down Expand Up @@ -855,3 +855,19 @@ def test_synthesizer_logger(mock_datetime, mock_generate_id):
' Total number of columns: 3\n'
' Synthesizer id: GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n'
)


@pytest.mark.parametrize('synthesizer', SYNTHESIZERS)
def test_sample_not_fitted(synthesizer):
    """Check that calling ``sample`` on a synthesizer that was never fitted raises an error."""
    # Setup
    # Build a brand-new instance of the parametrized synthesizer's class from empty
    # metadata; ``fit`` is never called on it in this test.
    empty_metadata = SingleTableMetadata()
    unfitted_synthesizer = synthesizer.__class__(empty_metadata)
    error_pattern = re.escape(
        'This synthesizer has not been fitted. Please fit your synthesizer first before'
        ' sampling synthetic data.'
    )

    # Run and Assert
    with pytest.raises(SamplingError, match=error_pattern):
        unfitted_synthesizer.sample(10)
16 changes: 15 additions & 1 deletion tests/unit/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from sdv import version
from sdv.constraints.errors import AggregateConstraintsError
from sdv.errors import ConstraintsNotMetError, SynthesizerInputError, VersionError
from sdv.errors import ConstraintsNotMetError, SamplingError, SynthesizerInputError, VersionError
from sdv.metadata.single_table import SingleTableMetadata
from sdv.sampling.tabular import Condition
from sdv.single_table import (
Expand Down Expand Up @@ -1399,6 +1399,20 @@ def test__sample_with_progress_bar_removing_temp_file(
mock_os.remove.assert_called_once_with('.sample.csv.temp')
mock_os.path.exists.assert_called_once_with('.sample.csv.temp')

def test_sample_not_fitted(self):
    """Test that ``sample`` raises a ``SamplingError`` when ``_fitted`` is ``False``."""
    # Setup
    # A bare mock stands in for the synthesizer; only ``_fitted`` is configured explicitly.
    unfitted_instance = Mock(_fitted=False)
    error_pattern = re.escape(
        'This synthesizer has not been fitted. Please fit your synthesizer first before'
        ' sampling synthetic data.'
    )

    # Run and Assert
    with pytest.raises(SamplingError, match=error_pattern):
        BaseSingleTableSynthesizer.sample(unfitted_instance, 10)

@patch('sdv.single_table.base.datetime')
def test_sample(self, mock_datetime, caplog):
"""Test that we use ``_sample_with_progress_bar`` in this method."""
Expand Down
Loading