"""Base Multi Table Synthesizer class."""
import contextlib
import datetime
import inspect
import operator
import warnings
from collections import defaultdict
from copy import deepcopy
import cloudpickle
import numpy as np
from tqdm import tqdm
from sdv import version
from sdv._utils import (
check_sdv_versions_and_warn,
check_synthesizer_version,
generate_synthesizer_id,
)
from sdv.errors import (
ConstraintsNotMetError,
InvalidDataError,
SamplingError,
SynthesizerInputError,
)
from sdv.logging import disable_single_table_logger, get_sdv_logger
from sdv.metadata.metadata import Metadata
from sdv.metadata.multi_table import MultiTableMetadata
from sdv.single_table.base import INT_REGEX_ZERO_ERROR_MESSAGE
from sdv.single_table.copulas import GaussianCopulaSynthesizer
SYNTHESIZER_LOGGER = get_sdv_logger('MultiTableSynthesizer')
DEPRECATION_MSG = (
"The 'MultiTableMetadata' is deprecated. Please use the new 'Metadata' class for synthesizers."
)
class BaseMultiTableSynthesizer:
"""Base class for multi table synthesizers.
The ``BaseMultiTableSynthesizer`` class defines the common API that all the
multi table synthesizers need to implement, as well as common functionality.
Args:
metadata (sdv.metadata.multi_table.MultiTableMetadata):
Multi table metadata representing the data tables that this synthesizer will be used
for.
locales (list or str):
The default locale(s) to use for AnonymizedFaker transformers.
Defaults to ``['en_US']``.
verbose (bool):
Whether to print progress for fitting or not.
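
    Example:
        A minimal end-to-end sketch. This base class is not instantiated
        directly; ``HMASynthesizer`` is shown as a concrete subclass, and
        ``fake_hotels`` is one of the SDV multi-table demo datasets:

        >>> from sdv.datasets.demo import download_demo
        >>> from sdv.multi_table import HMASynthesizer
        >>> data, metadata = download_demo('multi_table', 'fake_hotels')
        >>> synthesizer = HMASynthesizer(metadata)
        >>> synthesizer.fit(data)
        >>> synthetic_data = synthesizer.sample(scale=1.0)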
"""
DEFAULT_SYNTHESIZER_KWARGS = None
_synthesizer = GaussianCopulaSynthesizer
_numpy_seed = 73251
@contextlib.contextmanager
def _set_temp_numpy_seed(self):
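        # Swap numpy's global RNG state for a stored seed/state so sampling is
        # deterministic, then restore the caller's state on exit. The advanced
        # state is saved back to ``self._numpy_seed`` so that consecutive
        # ``sample`` calls continue the random sequence rather than repeat it.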
initial_state = np.random.get_state()
if isinstance(self._numpy_seed, int):
np.random.seed(self._numpy_seed)
np.random.default_rng(self._numpy_seed)
else:
np.random.set_state(self._numpy_seed)
np.random.default_rng(self._numpy_seed[1])
try:
yield
finally:
self._numpy_seed = np.random.get_state()
np.random.set_state(initial_state)
def _initialize_models(self):
with disable_single_table_logger():
for table_name, table_metadata in self.metadata.tables.items():
synthesizer_parameters = {'locales': self.locales}
synthesizer_parameters.update(self._table_parameters.get(table_name, {}))
metadata_dict = {'tables': {table_name: table_metadata.to_dict()}}
metadata = Metadata.load_from_dict(metadata_dict)
self._table_synthesizers[table_name] = self._synthesizer(
metadata=metadata, **synthesizer_parameters
)
self._table_synthesizers[table_name]._data_processor.table_name = table_name
def _get_pbar_args(self, **kwargs):
"""Return a dictionary with the updated keyword args for a progress bar."""
pbar_args = {'disable': not self.verbose}
pbar_args.update(kwargs)
return pbar_args
def _print(self, text='', **kwargs):
if self.verbose:
print(text, **kwargs) # noqa: T201
def _check_metadata_updated(self):
if self.metadata._check_updated_flag():
self.metadata._reset_updated_flag()
warnings.warn(
"We strongly recommend saving the metadata using 'save_to_json' for replicability"
' in future SDV versions.'
)
def __init__(self, metadata, locales=['en_US'], synthesizer_kwargs=None):
self.metadata = metadata
if type(metadata) is MultiTableMetadata:
warnings.warn(DEPRECATION_MSG, FutureWarning)
with warnings.catch_warnings():
warnings.filterwarnings('ignore', message=r'.*column relationship.*')
self.metadata.validate()
self._check_metadata_updated()
self.locales = locales
self.verbose = False
self.extended_columns = defaultdict(dict)
self._table_synthesizers = {}
self._table_parameters = defaultdict(dict)
self._original_table_columns = {}
if synthesizer_kwargs is not None:
warn_message = (
'The `synthesizer_kwargs` parameter is deprecated as of SDV 1.2.0 and does not '
'affect the synthesizer. Please use the `set_table_parameters` method instead.'
)
warnings.warn(warn_message, FutureWarning)
if self.DEFAULT_SYNTHESIZER_KWARGS:
for table_name in self.metadata.tables:
self._table_parameters[table_name] = deepcopy(self.DEFAULT_SYNTHESIZER_KWARGS)
self._initialize_models()
self._fitted = False
self._creation_date = datetime.datetime.today().strftime('%Y-%m-%d')
self._fitted_date = None
self._fitted_sdv_version = None
self._fitted_sdv_enterprise_version = None
self._synthesizer_id = generate_synthesizer_id(self)
SYNTHESIZER_LOGGER.info({
'EVENT': 'Instance',
'TIMESTAMP': datetime.datetime.now(),
'SYNTHESIZER CLASS NAME': self.__class__.__name__,
'SYNTHESIZER ID': self._synthesizer_id,
})
def set_address_columns(self, table_name, column_names, anonymization_level='full'):
"""Set the address multi-column transformer.
Args:
table_name (str):
The name of the table for which the address transformer should be set.
column_names (tuple[str]):
The column names to be used for the address transformer.
anonymization_level (str):
The anonymization level to use for the address transformer.
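
        Example:
            An illustrative sketch; the table and column names are hypothetical:

            >>> synthesizer.set_address_columns(
            ...     'guests', ('street', 'city', 'country'), anonymization_level='full'
            ... )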
"""
self._validate_table_name(table_name)
self._table_synthesizers[table_name].set_address_columns(column_names, anonymization_level)
def get_table_parameters(self, table_name):
"""Return the parameters for the given table's synthesizer.
Args:
table_name (str):
Table name for which the parameters should be retrieved.
Returns:
parameters (dict):
A dictionary with the following structure:
{
'synthesizer_name': the string name of the synthesizer for that table,
'synthesizer_parameters': the parameters used to instantiate the synthesizer
}
"""
table_synthesizer = self._table_synthesizers.get(table_name)
if not table_synthesizer:
table_params = {'synthesizer_name': None, 'synthesizer_parameters': {}}
else:
table_params = {
'synthesizer_name': type(table_synthesizer).__name__,
'synthesizer_parameters': table_synthesizer.get_parameters(),
}
return table_params
def get_parameters(self):
"""Return the parameters used to instantiate the synthesizer and all table synthesizers.
Returns:
parameters (dict):
A dictionary representing the parameters used to instantiate the synthesizer.
"""
parameters = inspect.signature(self.__init__).parameters
instantiated_parameters = {}
for parameter_name in parameters:
if parameter_name != 'metadata':
instantiated_parameters[parameter_name] = self.__dict__.get(parameter_name)
return instantiated_parameters
def set_table_parameters(self, table_name, table_parameters):
"""Update the table's synthesizer instantiation parameters.
Args:
table_name (str):
                Table name for which the parameters should be updated.
table_parameters (dict):
A dictionary with the parameters as keys and the values to be used to instantiate
the table's synthesizer.
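
        Example:
            A sketch assuming the default ``GaussianCopulaSynthesizer`` backend and
            a hypothetical table name:

            >>> synthesizer.set_table_parameters(
            ...     'hotels', {'default_distribution': 'truncnorm'}
            ... )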
"""
table_metadata = self.metadata.get_table_metadata(table_name)
self._table_synthesizers[table_name] = self._synthesizer(
metadata=table_metadata, **table_parameters
)
self._table_parameters[table_name].update(deepcopy(table_parameters))
def get_metadata(self):
"""Return the ``Metadata`` for this synthesizer."""
return Metadata.load_from_dict(self.metadata.to_dict())
def _validate_all_tables(self, data):
"""Validate every table of the data has a valid table/metadata pair."""
errors = []
for table_name, table_data in data.items():
try:
self._table_synthesizers[table_name].validate(table_data)
except InvalidDataError as error:
error_msg = f"Table: '{table_name}'"
for _error in error.errors:
error_msg += f'\nError: {_error}'
errors.append(error_msg)
except ValueError as error:
errors.append(str(error))
except KeyError:
continue
return errors
def validate(self, data):
"""Validate the data.
        Validate that the metadata matches the data and that every table's constraints are valid.
Args:
data (dict):
A dictionary of table names to pd.DataFrames.
"""
errors = []
constraints_errors = []
self.metadata.validate_data(data)
for table_name in data:
if table_name in self._table_synthesizers:
try:
self._table_synthesizers[table_name]._validate_constraints(data[table_name])
except ConstraintsNotMetError as error:
constraints_errors.append(error)
# Validate rules specific to each synthesizer
errors += self._table_synthesizers[table_name]._validate(data[table_name])
if constraints_errors:
raise ConstraintsNotMetError(constraints_errors)
elif errors:
raise InvalidDataError(errors)
def _validate_table_name(self, table_name):
if table_name not in self._table_synthesizers:
raise ValueError(
'The provided data does not match the metadata:'
f"\nTable '{table_name}' is not present in the metadata."
)
def _assign_table_transformers(self, synthesizer, table_name, table_data):
"""Update the ``synthesizer`` to ignore the foreign keys while preprocessing the data."""
synthesizer.auto_assign_transformers(table_data)
foreign_key_columns = self.metadata._get_all_foreign_keys(table_name)
column_name_to_transformers = {column_name: None for column_name in foreign_key_columns}
synthesizer.update_transformers(column_name_to_transformers)
def auto_assign_transformers(self, data):
"""Automatically assign the required transformers for the given data and constraints.
This method will automatically set a configuration to the ``rdt.HyperTransformer``
with the required transformers for the current data.
Args:
data (dict):
Mapping of table name to pandas.DataFrame.
Raises:
InvalidDataError:
If a table of the data is not present in the metadata.
"""
for table_name, table_data in data.items():
self._validate_table_name(table_name)
synthesizer = self._table_synthesizers[table_name]
self._assign_table_transformers(synthesizer, table_name, table_data)
def get_transformers(self, table_name):
"""Get a dictionary mapping of ``column_name`` and ``rdt.transformers``.
A dictionary representing the column names and the transformers that will be used
to transform the data.
Args:
table_name (string):
The name of the table of which to get the transformers.
Returns:
dict:
A dictionary mapping with column names and transformers.
Raises:
ValueError:
If ``table_name`` is not present in the metadata.
"""
self._validate_table_name(table_name)
return self._table_synthesizers[table_name].get_transformers()
def update_transformers(self, table_name, column_name_to_transformer):
"""Update any of the transformers assigned to each of the column names.
Args:
table_name (string):
The name of the table of which to update the transformers.
column_name_to_transformer (dict):
Dict mapping column names to transformers to be used for that column.
Raises:
ValueError:
If ``table_name`` is not present in the metadata.
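
        Example:
            A sketch with a hypothetical table and column; ``FloatFormatter`` is a
            standard ``rdt`` transformer:

            >>> from rdt.transformers.numerical import FloatFormatter
            >>> synthesizer.update_transformers('hotels', {'rating': FloatFormatter()})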
"""
self._validate_table_name(table_name)
self._table_synthesizers[table_name].update_transformers(column_name_to_transformer)
def _store_and_convert_original_cols(self, data):
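        # Cast every column name to ``str`` so downstream transformers can handle
        # it, remembering the original names in ``self._original_table_columns``
        # so ``sample`` can restore them. Tables that contained integer column
        # names are returned so their input frames can be renamed back later.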
list_of_changed_tables = []
for table, dataframe in data.items():
data_columns = dataframe.columns
col_name_mapping = {str(col): col for col in data_columns}
reverse_col_name_mapping = {col: str(col) for col in data_columns}
self._original_table_columns[table] = col_name_mapping
dataframe = dataframe.rename(columns=reverse_col_name_mapping)
for column in data_columns:
if isinstance(column, int):
list_of_changed_tables.append(table)
break
data[table] = dataframe
return list_of_changed_tables
def _transform_helper(self, data):
"""Stub method for transforming data patterns."""
return data
def preprocess(self, data):
"""Transform the raw data to numerical space.
Args:
data (dict):
Dictionary mapping each table name to a ``pandas.DataFrame``.
Returns:
dict:
A dictionary with the preprocessed data.
"""
list_of_changed_tables = self._store_and_convert_original_cols(data)
data = self._transform_helper(data)
self.validate(data)
if self._fitted:
warnings.warn(
'This model has already been fitted. To use the new preprocessed data, '
"please refit the model using 'fit' or 'fit_processed_data'."
)
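        # Preprocess each table with its own single-table synthesizer; foreign key
        # columns are assigned a ``None`` transformer so they pass through as-is.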
processed_data = {}
pbar_args = self._get_pbar_args(desc='Preprocess Tables')
for table_name, table_data in tqdm(data.items(), **pbar_args):
try:
synthesizer = self._table_synthesizers[table_name]
self._assign_table_transformers(synthesizer, table_name, table_data)
processed_data[table_name] = synthesizer._preprocess(table_data)
except SynthesizerInputError as e:
if INT_REGEX_ZERO_ERROR_MESSAGE in str(e):
raise SynthesizerInputError(
f'Primary key for table "{table_name}" {INT_REGEX_ZERO_ERROR_MESSAGE}'
)
raise e
for table in list_of_changed_tables:
data[table] = data[table].rename(columns=self._original_table_columns[table])
return processed_data
def _model_tables(self, augmented_data):
"""Model the augmented tables.
Args:
augmented_data (dict):
Dictionary mapping each table name to an augmented ``pandas.DataFrame``.
"""
raise NotImplementedError()
def _augment_tables(self, processed_data):
"""Augment the processed data.
Args:
processed_data (dict):
Dictionary mapping each table name to a preprocessed ``pandas.DataFrame``.
"""
raise NotImplementedError()
def fit_processed_data(self, processed_data):
"""Fit this model to the transformed data.
Args:
processed_data (dict):
Dictionary mapping each table name to a preprocessed ``pandas.DataFrame``.
"""
total_rows = 0
total_columns = 0
for table in processed_data.values():
total_rows += len(table)
total_columns += len(table.columns)
SYNTHESIZER_LOGGER.info({
'EVENT': 'Fit processed data',
'TIMESTAMP': datetime.datetime.now(),
'SYNTHESIZER CLASS NAME': self.__class__.__name__,
'SYNTHESIZER ID': self._synthesizer_id,
'TOTAL NUMBER OF TABLES': len(processed_data),
'TOTAL NUMBER OF ROWS': total_rows,
'TOTAL NUMBER OF COLUMNS': total_columns,
})
check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt)
with disable_single_table_logger():
augmented_data = self._augment_tables(processed_data)
self._model_tables(augmented_data)
self._fitted = True
self._fitted_date = datetime.datetime.today().strftime('%Y-%m-%d')
self._fitted_sdv_version = getattr(version, 'public', None)
self._fitted_sdv_enterprise_version = getattr(version, 'enterprise', None)
def fit(self, data):
"""Fit this model to the original data.
Args:
data (dict):
Dictionary mapping each table name to a ``pandas.DataFrame`` in the raw format
(before any transformations).
"""
total_rows = 0
total_columns = 0
for table in data.values():
total_rows += len(table)
total_columns += len(table.columns)
SYNTHESIZER_LOGGER.info({
'EVENT': 'Fit',
'TIMESTAMP': datetime.datetime.now(),
'SYNTHESIZER CLASS NAME': self.__class__.__name__,
'SYNTHESIZER ID': self._synthesizer_id,
'TOTAL NUMBER OF TABLES': len(data),
'TOTAL NUMBER OF ROWS': total_rows,
'TOTAL NUMBER OF COLUMNS': total_columns,
})
check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt)
self._check_metadata_updated()
self._fitted = False
processed_data = self.preprocess(data)
self._print(text='\n', end='')
self.fit_processed_data(processed_data)
def reset_sampling(self):
"""Reset the sampling to the state that was left right after fitting."""
self._numpy_seed = 73251
for synthesizer in self._table_synthesizers.values():
synthesizer.reset_sampling()
def _sample(self, scale):
raise NotImplementedError()
def _reverse_transform_helper(self, sampled_data):
"""Stub method for reverse transforming data patterns."""
return sampled_data
def sample(self, scale=1.0):
"""Generate synthetic data for the entire dataset.
Args:
scale (float):
                A float representing how much to scale the data by. If ``scale`` is set to
                ``1.0``, the sampled tables match the sizes of the original data. If ``scale``
                is greater than ``1.0``, create more rows than the original data by a factor
                of ``scale``; if it is lower than ``1.0``, create proportionally fewer rows.
                Defaults to ``1.0``.
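
        Returns:
            dict:
                A dictionary mapping each table name to a sampled ``pandas.DataFrame``.

        Example:
            >>> synthetic_data = synthesizer.sample(scale=1.5)  # ~1.5x the original rows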
"""
if not self._fitted:
raise SamplingError(
'This synthesizer has not been fitted. Please fit your synthesizer first before '
'sampling synthetic data.'
)
if type(scale) not in (float, int) or not scale > 0:
raise SynthesizerInputError(
f"Invalid parameter for 'scale' ({scale}). Please provide a number that is >0.0."
)
with self._set_temp_numpy_seed(), disable_single_table_logger():
sampled_data = self._sample(scale=scale)
sampled_data = self._reverse_transform_helper(sampled_data)
total_rows = 0
total_columns = 0
for table in sampled_data.values():
total_rows += len(table)
total_columns += len(table.columns)
table_columns = getattr(self, '_original_table_columns', {})
for table in sampled_data:
table_data = sampled_data[table][self.get_metadata().get_column_names(table)]
if table in table_columns:
if isinstance(table_columns[table], dict):
table_data = table_data.rename(columns=table_columns[table])
else:
table_data.columns = table_columns[table]
sampled_data[table] = table_data
SYNTHESIZER_LOGGER.info({
'EVENT': 'Sample',
'TIMESTAMP': datetime.datetime.now(),
'SYNTHESIZER CLASS NAME': self.__class__.__name__,
'SYNTHESIZER ID': self._synthesizer_id,
'TOTAL NUMBER OF TABLES': len(sampled_data),
'TOTAL NUMBER OF ROWS': total_rows,
'TOTAL NUMBER OF COLUMNS': total_columns,
})
return sampled_data
def get_learned_distributions(self, table_name):
"""Get the marginal distributions used by the ``GaussianCopula`` for a table.
        Return a dictionary mapping each column name to the name of its learned
        marginal distribution and the parameters learned for it.
Args:
table_name (str):
Table name for which the parameters should be retrieved.
Returns:
dict:
Dictionary containing the distributions used or detected for each column and the
learned parameters for those.
"""
synthesizer = self._table_synthesizers[table_name]
if hasattr(synthesizer, 'get_learned_distributions'):
return synthesizer.get_learned_distributions()
raise SynthesizerInputError(
f"Learned distributions are not available for the '{table_name}' "
f"table because it uses the '{synthesizer.__class__.__name__}'."
)
def get_loss_values(self, table_name):
"""Get the loss values from a model for a table.
        Return a pandas DataFrame with the loss values per epoch for GAN-based
        synthesizers.
Args:
table_name (str):
Table name for which the parameters should be retrieved.
Returns:
pd.DataFrame:
                DataFrame of loss values per epoch.
"""
if table_name not in self._table_synthesizers:
raise ValueError(f"Table '{table_name}' is not present in the metadata.")
synthesizer = self._table_synthesizers[table_name]
if hasattr(synthesizer, 'get_loss_values'):
return synthesizer.get_loss_values()
raise SynthesizerInputError(
f"Loss values are not available for table '{table_name}' "
'because the table does not use a GAN-based model.'
)
def _validate_constraints_to_be_added(self, constraints):
for constraint_dict in constraints:
if 'table_name' not in constraint_dict.keys():
raise SynthesizerInputError(
"A constraint is missing required parameter 'table_name'. "
'Please add this parameter to your constraint definition.'
)
if constraint_dict['constraint_class'] == 'Unique':
raise SynthesizerInputError(
"The constraint class 'Unique' is not currently supported for multi-table"
' synthesizers. Please remove the constraint for this synthesizer.'
)
if self._fitted:
warnings.warn(
"For these constraints to take effect, please refit the synthesizer using 'fit'."
)
def add_constraints(self, constraints):
"""Add constraints to the synthesizer.
Args:
constraints (list):
List of constraints described as dictionaries in the following format:
* ``constraint_class``: Name of the constraint to apply.
* ``table_name``: Name of the table where to apply the constraint.
* ``constraint_parameters``: A dictionary with the constraint parameters.
Raises:
SynthesizerInputError:
Raises when the ``Unique`` constraint is passed.
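
        Example:
            A sketch with a hypothetical table and columns, using the built-in
            ``Inequality`` constraint class:

            >>> synthesizer.add_constraints([
            ...     {
            ...         'constraint_class': 'Inequality',
            ...         'table_name': 'hotels',
            ...         'constraint_parameters': {
            ...             'low_column_name': 'checkin_date',
            ...             'high_column_name': 'checkout_date',
            ...         },
            ...     }
            ... ])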
"""
self._validate_constraints_to_be_added(constraints)
for constraint in constraints:
constraint = deepcopy(constraint)
synthesizer = self._table_synthesizers[constraint.pop('table_name')]
synthesizer._data_processor.add_constraints([constraint])
def get_constraints(self):
"""Get constraints of the synthesizer.
Returns:
list:
List of dictionaries describing the constraints of the synthesizer.
"""
constraints = []
for table_name, synthesizer in self._table_synthesizers.items():
for constraint in synthesizer.get_constraints():
constraint['table_name'] = table_name
constraints.append(constraint)
return constraints
def load_custom_constraint_classes(self, filepath, class_names):
"""Load a custom constraint class for each table's synthesizer.
Args:
filepath (str):
String representing the absolute or relative path to the python file where
the custom constraints are declared.
class_names (list):
A list of custom constraint classes to be imported.
"""
for synthesizer in self._table_synthesizers.values():
synthesizer.load_custom_constraint_classes(filepath, class_names)
def add_custom_constraint_class(self, class_object, class_name):
"""Add a custom constraint class for the synthesizer to use.
Args:
class_object (sdv.constraints.Constraint):
A custom constraint class object.
class_name (str):
The name to assign this custom constraint class. This will be the name to use
when writing a constraint dictionary for ``add_constraints``.
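
        Example:
            A sketch using SDV's ``create_custom_constraint_class`` helper with
            hypothetical validity logic:

            >>> from sdv.constraints import create_custom_constraint_class
            >>> MyConstraint = create_custom_constraint_class(
            ...     is_valid_fn=lambda column_names, data: data[column_names[0]] > 0
            ... )
            >>> synthesizer.add_custom_constraint_class(MyConstraint, 'MyConstraint')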
"""
for synthesizer in self._table_synthesizers.values():
synthesizer.add_custom_constraint_class(class_object, class_name)
def get_info(self):
"""Get dictionary with information regarding the synthesizer.
        Returns:
dict:
                * ``class_name``: synthesizer class name
                * ``creation_date``: date of creation
                * ``is_fit``: whether or not the synthesizer has been fitted
                * ``last_fit_date``: date of the last time it was fitted
                * ``fitted_sdv_version``: version of SDV used when it was last fitted
"""
info = {
'class_name': self.__class__.__name__,
'creation_date': self._creation_date,
'is_fit': self._fitted,
'last_fit_date': self._fitted_date,
'fitted_sdv_version': self._fitted_sdv_version,
}
if self._fitted_sdv_enterprise_version:
info['fitted_sdv_enterprise_version'] = self._fitted_sdv_enterprise_version
return info
def _validate_fit_before_save(self):
"""Validate that the synthesizer has been fitted before saving."""
if not self._fitted:
warnings.warn(
'You are saving a synthesizer that has not yet been fitted. You will not be able '
'to sample synthetic data without fitting. We recommend fitting the synthesizer '
'first and then saving.'
)
def save(self, filepath):
"""Save this instance to the given path using cloudpickle.
Args:
filepath (str):
Path where the instance will be serialized.
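
        Example:
            >>> synthesizer.save('my_synthesizer.pkl')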
"""
self._validate_fit_before_save()
synthesizer_id = getattr(self, '_synthesizer_id', None)
SYNTHESIZER_LOGGER.info({
'EVENT': 'Save',
'TIMESTAMP': datetime.datetime.now(),
'SYNTHESIZER CLASS NAME': self.__class__.__name__,
'SYNTHESIZER ID': synthesizer_id,
})
with open(filepath, 'wb') as output:
cloudpickle.dump(self, output)
@classmethod
def load(cls, filepath):
"""Load a multi-table synthesizer from a given path.
Args:
filepath (str):
A string describing the filepath of your saved synthesizer.
Returns:
MultiTableSynthesizer:
The loaded synthesizer.
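
        Example:
            >>> synthesizer = HMASynthesizer.load('my_synthesizer.pkl')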
"""
with open(filepath, 'rb') as f:
try:
synthesizer = cloudpickle.load(f)
except RuntimeError as e:
err_msg = (
'Attempting to deserialize object on a CUDA device but '
'torch.cuda.is_available() is False. If you are running on a CPU-only machine,'
" please use torch.load with map_location=torch.device('cpu') "
'to map your storages to the CPU.'
)
if str(e) == err_msg:
raise SamplingError(
'This synthesizer was created on a machine with GPU but the current '
'machine is CPU-only. This feature is currently unsupported. We recommend'
' sampling on the same GPU-enabled machine.'
)
raise e
check_synthesizer_version(synthesizer)
check_sdv_versions_and_warn(synthesizer)
if getattr(synthesizer, '_synthesizer_id', None) is None:
synthesizer._synthesizer_id = generate_synthesizer_id(synthesizer)
SYNTHESIZER_LOGGER.info({
'EVENT': 'Load',
'TIMESTAMP': datetime.datetime.now(),
'SYNTHESIZER CLASS NAME': synthesizer.__class__.__name__,
'SYNTHESIZER ID': synthesizer._synthesizer_id,
})
return synthesizer