Skip to content

Commit

Permalink
Add support for Gretel Amplify
Browse files Browse the repository at this point in the history
  • Loading branch information
pimlock authored Nov 2, 2022
2 parents a50a38d + ccd17d5 commit 13f4e8f
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 14 deletions.
20 changes: 20 additions & 0 deletions notebooks/trainer-examples.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,26 @@
"model = trainer.Trainer.load()\n",
"model.generate(seed_df=seed_df)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Use Gretel amplify to generate large amount of data (GBs)\n",
"\n",
"from gretel_trainer import trainer\n",
"from gretel_trainer.models import GretelAmplify\n",
"\n",
"dataset = \"https://gretel-public-website.s3-us-west-2.amazonaws.com/datasets/USAdultIncome5k.csv\"\n",
"\n",
"model_type = GretelAmplify()\n",
"\n",
"model = trainer.Trainer(model_type=model_type)\n",
"model.train(dataset)\n",
"model.generate()"
]
}
],
"metadata": {
Expand Down
63 changes: 49 additions & 14 deletions src/gretel_trainer/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@


HIGH_COLUMN_THRESHOLD = 20
HIGH_RECORD_THRESHOLD = 50000
HIGH_RECORD_THRESHOLD = 50_000
LOW_COLUMN_THRESHOLD = 4
LOW_RECORD_THRESHOLD = 1000
LOW_RECORD_THRESHOLD = 1_000


def determine_best_model(df: pd.DataFrame):
def determine_best_model(df: pd.DataFrame) -> _BaseConfig:
"""
Determine the Gretel model best suited for generating synthetic data
for your dataset.
for your dataset.
Args:
df (pd.DataFrame): Pandas DataFrame containing the data used to train a synthetic model.
Expand Down Expand Up @@ -63,12 +63,14 @@ def __init__(
self.max_header_clusters = max_header_clusters
self.enable_privacy_filters = enable_privacy_filters

self._handle_privacy_filters()
self.validate()

def _handle_privacy_filters(self):
if not self.enable_privacy_filters:
logging.warning("Privacy filters disabled. Enable with the `enable_privacy_filters` param.")
self.update_params({"outliers": None, "similarity": None})

self.validate()

def update_params(self, params: dict):
"""Convenience function to update model specific parameters from the base config by key value.
Expand Down Expand Up @@ -115,13 +117,13 @@ class GretelLSTM(_BaseConfig):
enable_privacy_filters (bool, optional): Default: False
"""
_max_header_clusters_limit: int = 30
_max_rows_limit: int = 5000000
_max_rows_limit: int = 5_000_000
_model_slug: str = "synthetics"

def __init__(
self,
config="synthetics/default",
max_rows=50000,
max_rows=50_000,
max_header_clusters=20,
enable_privacy_filters=False,
):
Expand All @@ -135,24 +137,24 @@ def __init__(

class GretelCTGAN(_BaseConfig):
"""
This model works well for high dimensional, largely numeric data. Use for datasets with more than 20 columns and/or 50,000 rows.
This model works well for high dimensional, largely numeric data. Use for datasets with more than 20 columns and/or 50,000 rows.
Not ideal if dataset contains free text field
Args:
config (str/dict, optional): Either a string representing the path to the config on the local filesystem, a string representing a path to the default Gretel configurations, or a dictionary containing the configurations. Default: "synthetics/default", a default Gretel configuration
config (str/dict, optional): Either a string representing the path to the config on the local filesystem, a string representing a path to the default Gretel configurations, or a dictionary containing the configurations. Default: "synthetics/high-dimensionality", a default Gretel configuration
max_rows (int, optional): The number of rows of synthetic data to generate. Defaults to 50000
max_header_clusters (int, optional): Default: 20
max_header_clusters (int, optional): Default: 500
enable_privacy_filters (bool, optional): Default: False
"""
_max_header_clusters_limit: int = 1000
_max_rows_limit: int = 5000000
_max_header_clusters_limit: int = 1_000
_max_rows_limit: int = 5_000_000
_model_slug: str = "ctgan"

def __init__(
self,
config="synthetics/high-dimensionality",
max_rows=50000,
max_rows=50_000,
max_header_clusters=500,
enable_privacy_filters=False,
):
Expand All @@ -162,3 +164,36 @@ def __init__(
max_header_clusters=max_header_clusters,
enable_privacy_filters=enable_privacy_filters,
)


class GretelAmplify(_BaseConfig):
    """
    This model is able to generate large quantities of data from real-world data or synthetic data.

    Note: this model doesn't currently support privacy filtering.

    Args:
        config (str/dict, optional): Either a string representing the path to the config on the local filesystem, a string representing a path to the default Gretel configurations, or a dictionary containing the configurations. Default: "synthetics/amplify", a default Gretel configuration for Amplify.
        max_rows (int, optional): The number of rows of synthetic data to generate. Defaults to 50000
        max_header_clusters (int, optional): Default: 500
    """

    _max_header_clusters_limit: int = 1_000
    _max_rows_limit: int = 1_000_000_000
    _model_slug: str = "amplify"

    def __init__(
        self,
        config="synthetics/amplify",
        max_rows=50_000,
        max_header_clusters=500,
    ):
        super().__init__(
            config=config,
            max_rows=max_rows,
            max_header_clusters=max_header_clusters,
            # Amplify has no privacy-filter support, so this is always False
            # and intentionally not exposed as a constructor parameter.
            enable_privacy_filters=False,
        )

    def _handle_privacy_filters(self) -> None:
        # Currently amplify doesn't support privacy filtering, so skip the
        # warning/param-update that the base class performs for this hook.
        pass
1 change: 1 addition & 0 deletions src/gretel_trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import os.path
from pathlib import Path
from typing import Optional

import pandas as pd
from gretel_client import configure_session
Expand Down

0 comments on commit 13f4e8f

Please sign in to comment.