From 82eaf3b2a63204e217287c3174e5fc9448d5301a Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Mon, 24 Oct 2022 18:06:05 -0700 Subject: [PATCH 1/2] Add support for Gretel Amplify --- notebooks/trainer-examples.ipynb | 20 +++++++++++++ src/gretel_trainer/models.py | 51 ++++++++++++++++++++++++-------- src/gretel_trainer/trainer.py | 1 + 3 files changed, 60 insertions(+), 12 deletions(-) diff --git a/notebooks/trainer-examples.ipynb b/notebooks/trainer-examples.ipynb index 2f8d0830..bfd4d566 100644 --- a/notebooks/trainer-examples.ipynb +++ b/notebooks/trainer-examples.ipynb @@ -109,6 +109,26 @@ "model = trainer.Trainer.load()\n", "model.generate(seed_df=seed_df)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Use Gretel amplify to generate large amount of data (GBs)\n", + "\n", + "from gretel_trainer import trainer\n", + "from gretel_trainer.models import GretelAmplify\n", + "\n", + "dataset = \"https://gretel-public-website.s3-us-west-2.amazonaws.com/datasets/USAdultIncome5k.csv\"\n", + "\n", + "model_type = GretelAmplify()\n", + "\n", + "model = trainer.Trainer(model_type=model_type)\n", + "model.train(dataset)\n", + "model.generate()" + ] } ], "metadata": { diff --git a/src/gretel_trainer/models.py b/src/gretel_trainer/models.py index 4ae7dc0f..db263edb 100644 --- a/src/gretel_trainer/models.py +++ b/src/gretel_trainer/models.py @@ -10,15 +10,15 @@ HIGH_COLUMN_THRESHOLD = 20 -HIGH_RECORD_THRESHOLD = 50000 +HIGH_RECORD_THRESHOLD = 50_000 LOW_COLUMN_THRESHOLD = 4 -LOW_RECORD_THRESHOLD = 1000 +LOW_RECORD_THRESHOLD = 1_000 -def determine_best_model(df: pd.DataFrame): +def determine_best_model(df: pd.DataFrame) -> _BaseConfig: """ Determine the Gretel model best suited for generating synthetic data - for your dataset. + for your dataset. Args: df (pd.DataFrame): Pandas DataFrame containing the data used to train a synthetic model. @@ -63,12 +63,14 @@ def __init__( self.max_header_clusters = max_header_clusters self.enable_privacy_filters = enable_privacy_filters + self._handle_privacy_filters() + self.validate() + + def _handle_privacy_filters(self): if not self.enable_privacy_filters: logging.warning("Privacy filters disabled. Enable with the `enable_privacy_filters` param.") self.update_params({"outliers": None, "similarity": None}) - self.validate() - def update_params(self, params: dict): """Convenience function to update model specific parameters from the base config by key value. @@ -115,13 +117,13 @@ class GretelLSTM(_BaseConfig): enable_privacy_filters (bool, optional): Default: False """ _max_header_clusters_limit: int = 30 - _max_rows_limit: int = 5000000 + _max_rows_limit: int = 5_000_000 _model_slug: str = "synthetics" def __init__( self, config="synthetics/default", - max_rows=50000, + max_rows=50_000, max_header_clusters=20, enable_privacy_filters=False, ): @@ -135,7 +137,7 @@ def __init__( class GretelCTGAN(_BaseConfig): """ - This model works well for high dimensional, largely numeric data. Use for datasets with more than 20 columns and/or 50,000 rows. + This model works well for high dimensional, largely numeric data. Use for datasets with more than 20 columns and/or 50,000 rows. Not ideal if dataset contains free text field @@ -145,14 +147,14 @@ class GretelCTGAN(_BaseConfig): max_header_clusters (int, optional): Default: 20 enable_privacy_filters (bool, optional): Default: False """ - _max_header_clusters_limit: int = 1000 - _max_rows_limit: int = 5000000 + _max_header_clusters_limit: int = 1_000 + _max_rows_limit: int = 5_000_000 _model_slug: str = "ctgan" def __init__( self, config="synthetics/high-dimensionality", - max_rows=50000, + max_rows=50_000, max_header_clusters=500, enable_privacy_filters=False, ): @@ -162,3 +164,28 @@ def __init__( max_header_clusters=max_header_clusters, enable_privacy_filters=enable_privacy_filters, ) + + +class GretelAmplify(_BaseConfig): + + _max_header_clusters_limit: int = 1_000 + _max_rows_limit: int = 1_000_000_000 + _model_slug: str = "amplify" + + def __init__( + self, + config="synthetics/amplify", + max_rows=50_000, + max_header_clusters=500, + ): + super().__init__( + config=config, + max_rows=max_rows, + max_header_clusters=max_header_clusters, + enable_privacy_filters=False, + ) + + def _handle_privacy_filters(self) -> None: + # Currently amplify doesn't support privacy filtering + pass + diff --git a/src/gretel_trainer/trainer.py b/src/gretel_trainer/trainer.py index eb0106ab..28f3ba0c 100644 --- a/src/gretel_trainer/trainer.py +++ b/src/gretel_trainer/trainer.py @@ -4,6 +4,7 @@ import logging import os.path from pathlib import Path +from typing import Optional import pandas as pd from gretel_client import configure_session From ccd17d55ae920139cdaf3001d292fd75c801fd48 Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Tue, 1 Nov 2022 13:03:34 -0700 Subject: [PATCH 2/2] Add docs. --- src/gretel_trainer/models.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/gretel_trainer/models.py b/src/gretel_trainer/models.py index db263edb..c114d011 100644 --- a/src/gretel_trainer/models.py +++ b/src/gretel_trainer/models.py @@ -142,9 +142,9 @@ class GretelCTGAN(_BaseConfig): Not ideal if dataset contains free text field Args: - config (str/dict, optional): Either a string representing the path to the config on the local filesystem, a string representing a path to the default Gretel configurations, or a dictionary containing the configurations. Default: "synthetics/default", a default Gretel configuration + config (str/dict, optional): Either a string representing the path to the config on the local filesystem, a string representing a path to the default Gretel configurations, or a dictionary containing the configurations. Default: "synthetics/high-dimensionality", a default Gretel configuration max_rows (int, optional): The number of rows of synthetic data to generate. Defaults to 50000 - max_header_clusters (int, optional): Default: 20 + max_header_clusters (int, optional): Default: 500 enable_privacy_filters (bool, optional): Default: False """ _max_header_clusters_limit: int = 1_000 @@ -167,7 +167,16 @@ def __init__( class GretelAmplify(_BaseConfig): + """ + This model is able to generate large quantities of data from real-world data or synthetic data. + + Note: this model doesn't currently support privacy filtering. + Args: + config (str/dict, optional): Either a string representing the path to the config on the local filesystem, a string representing a path to the default Gretel configurations, or a dictionary containing the configurations. Default: "synthetics/amplify", a default Gretel configuration for Amplify. + max_rows (int, optional): The number of rows of synthetic data to generate. Defaults to 50000 + max_header_clusters (int, optional): Default: 50 + """ _max_header_clusters_limit: int = 1_000 _max_rows_limit: int = 1_000_000_000 _model_slug: str = "amplify" @@ -188,4 +197,3 @@ def __init__( def _handle_privacy_filters(self) -> None: # Currently amplify doesn't support privacy filtering pass -