From e5e879757554556e0740b390c1d421cc771b61aa Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Tue, 7 May 2024 09:41:33 +0100 Subject: [PATCH 1/3] def --- sdv/single_table/base.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index 60656a4a2..0b6fa521b 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -23,7 +23,8 @@ _groupby_list, check_sdv_versions_and_warn, check_synthesizer_version, generate_synthesizer_id) from sdv.constraints.errors import AggregateConstraintsError from sdv.data_processing.data_processor import DataProcessor -from sdv.errors import ConstraintsNotMetError, InvalidDataError, SynthesizerInputError +from sdv.errors import ( + ConstraintsNotMetError, InvalidDataError, SamplingError, SynthesizerInputError) from sdv.logging.utils import get_sdv_logger from sdv.single_table.utils import check_num_rows, handle_sampling_error, validate_file_path @@ -871,6 +872,12 @@ def sample(self, num_rows, max_tries_per_batch=100, batch_size=None, output_file pandas.DataFrame: Sampled data. """ + if not self._fitted: + raise SamplingError( + 'This synthesizer has not been fitted. Please fit your synthesizer first before' + ' sampling synthetic data.' + ) + sample_timestamp = datetime.datetime.now() has_constraints = bool(self._data_processor._constraints) has_batches = batch_size is not None and batch_size != num_rows From e70ac40bb3b3597165527d87bc158c7832cf7193 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Tue, 7 May 2024 09:42:05 +0100 Subject: [PATCH 2/3] unit test --- tests/unit/single_table/test_base.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/tests/unit/single_table/test_base.py b/tests/unit/single_table/test_base.py index 197141e69..932fa52da 100644 --- a/tests/unit/single_table/test_base.py +++ b/tests/unit/single_table/test_base.py @@ -12,7 +12,7 @@ from sdv import version from sdv.constraints.errors import AggregateConstraintsError -from sdv.errors import ConstraintsNotMetError, SynthesizerInputError, VersionError +from sdv.errors import ConstraintsNotMetError, SamplingError, SynthesizerInputError, VersionError from sdv.metadata.single_table import SingleTableMetadata from sdv.sampling.tabular import Condition from sdv.single_table import ( @@ -1399,6 +1399,20 @@ def test__sample_with_progress_bar_removing_temp_file( mock_os.remove.assert_called_once_with('.sample.csv.temp') mock_os.path.exists.assert_called_once_with('.sample.csv.temp') + def test_sample_not_fitted(self): + """Test that ``sample`` raises an error when the synthesizer is not fitted.""" + # Setup + instance = Mock() + instance._fitted = False + expected_message = re.escape( + 'This synthesizer has not been fitted. Please fit your synthesizer first before' + ' sampling synthetic data.' + ) + + # Run and Assert + with pytest.raises(SamplingError, match=expected_message): + BaseSingleTableSynthesizer.sample(instance, 10) + @patch('sdv.single_table.base.datetime') def test_sample(self, mock_datetime, caplog): """Test that we use ``_sample_with_progress_bar`` in this method.""" From e80a160ae5a122f5f8bdb5d837ca08e70ae1c8b5 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Tue, 7 May 2024 09:42:16 +0100 Subject: [PATCH 3/3] integration test --- tests/integration/single_table/test_base.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/tests/integration/single_table/test_base.py b/tests/integration/single_table/test_base.py index 8c7ea2601..d0c105987 100644 --- a/tests/integration/single_table/test_base.py +++ b/tests/integration/single_table/test_base.py @@ -13,7 +13,7 @@ from sdv import version from sdv.datasets.demo import download_demo -from sdv.errors import SynthesizerInputError, VersionError +from sdv.errors import SamplingError, SynthesizerInputError, VersionError from sdv.metadata import SingleTableMetadata from sdv.sampling import Condition from sdv.single_table import ( @@ -855,3 +855,19 @@ def test_synthesizer_logger(mock_datetime, mock_generate_id): ' Total number of columns: 3\n' ' Synthesizer id: GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' ) + + +@pytest.mark.parametrize('synthesizer', SYNTHESIZERS) +def test_sample_not_fitted(synthesizer): + """Test that a synthesizer raises an error when trying to sample without fitting.""" + # Setup + metadata = SingleTableMetadata() + synthesizer = synthesizer.__class__(metadata) + expected_message = re.escape( + 'This synthesizer has not been fitted. Please fit your synthesizer first before' + ' sampling synthetic data.' + ) + + # Run and Assert + with pytest.raises(SamplingError, match=expected_message): + synthesizer.sample(10)