Skip to content

Commit

Permalink
Add support for Gretel Amplify
Browse files Browse the repository at this point in the history
  • Loading branch information
pimlock committed Oct 25, 2022
1 parent ba27d10 commit c785ecb
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 11 deletions.
20 changes: 20 additions & 0 deletions notebooks/trainer-examples.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,26 @@
"model = trainer.Trainer.load()\n",
"model.generate(seed_df=seed_df)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Use Gretel Amplify to generate large amounts of data (GBs)\n",
"\n",
"from gretel_trainer import trainer\n",
"from gretel_trainer.models import GretelAmplify\n",
"\n",
"dataset = \"https://gretel-public-website.s3-us-west-2.amazonaws.com/datasets/USAdultIncome5k.csv\"\n",
"\n",
"model_type = GretelAmplify()\n",
"\n",
"model = trainer.Trainer(model_type=model_type)\n",
"model.train(dataset)\n",
"model.generate()"
]
}
],
"metadata": {
Expand Down
48 changes: 37 additions & 11 deletions src/gretel_trainer/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@


HIGH_COLUMN_THRESHOLD = 20
HIGH_RECORD_THRESHOLD = 50000
HIGH_RECORD_THRESHOLD = 50_000
LOW_COLUMN_THRESHOLD = 4
LOW_RECORD_THRESHOLD = 1000
LOW_RECORD_THRESHOLD = 1_000


def determine_best_model(df: pd.DataFrame):
def determine_best_model(df: pd.DataFrame) -> _BaseConfig:
row_count, column_count = df.shape

if row_count > HIGH_RECORD_THRESHOLD or column_count > HIGH_COLUMN_THRESHOLD:
Expand Down Expand Up @@ -53,12 +53,14 @@ def __init__(
self.max_header_clusters = max_header_clusters
self.enable_privacy_filters = enable_privacy_filters

self._handle_privacy_filters()
self.validate()

def _handle_privacy_filters(self):
if not self.enable_privacy_filters:
logging.warning("Privacy filters disabled. Enable with the `enable_privacy_filters` param.")
self.update_params({"outliers": None, "similarity": None})

self.validate()

def update_params(self, params: dict):
"""Convenience function to update model specific parameters from the base config by key value.
Expand Down Expand Up @@ -95,13 +97,13 @@ def _replace_nested_key(self, data, key, value) -> dict:
class GretelLSTM(_BaseConfig):

_max_header_clusters_limit: int = 30
_max_rows_limit: int = 5000000
_max_rows_limit: int = 5_000_000
_model_slug: str = "synthetics"

def __init__(
self,
config="synthetics/default",
max_rows=50000,
max_rows=50_000,
max_header_clusters=20,
enable_privacy_filters=False,
):
Expand All @@ -115,14 +117,14 @@ def __init__(

class GretelCTGAN(_BaseConfig):

_max_header_clusters_limit: int = 1000
_max_rows_limit: int = 5000000
_max_header_clusters_limit: int = 1_000
_max_rows_limit: int = 5_000_000
_model_slug: str = "ctgan"

def __init__(
self,
config="synthetics/high-dimensionality",
max_rows=50000,
max_rows=50_000,
max_header_clusters=500,
enable_privacy_filters=False,
):
Expand All @@ -131,4 +133,28 @@ def __init__(
max_rows=max_rows,
max_header_clusters=max_header_clusters,
enable_privacy_filters=enable_privacy_filters,
)
)


class GretelAmplify(_BaseConfig):
    """Model configuration for Gretel Amplify.

    Amplify currently has no privacy-filter support, so this config never
    exposes an ``enable_privacy_filters`` option: the base initializer is
    always given ``enable_privacy_filters=False`` and the privacy-filter
    hook is overridden to do nothing.

    Args:
        config: Name of the base configuration to start from.
        max_rows: Forwarded unchanged to the base config as ``max_rows``.
        max_header_clusters: Forwarded unchanged to the base config as
            ``max_header_clusters``.
    """

    # Class-level settings read by the base config.
    # NOTE(review): presumably the two limits are enforced by the base
    # class's validate() -- confirm against _BaseConfig.
    _model_slug: str = "amplify"
    _max_rows_limit: int = 1_000_000_000
    _max_header_clusters_limit: int = 1_000

    def __init__(
        self,
        config="synthetics/amplify",
        max_rows=50_000,
        max_header_clusters=500,
    ):
        # Privacy filters are hard-wired off; callers cannot enable them here.
        super().__init__(
            enable_privacy_filters=False,
            max_header_clusters=max_header_clusters,
            max_rows=max_rows,
            config=config,
        )

    def _handle_privacy_filters(self) -> None:
        """Do nothing: Amplify currently doesn't support privacy filtering."""
1 change: 1 addition & 0 deletions src/gretel_trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import os.path
from pathlib import Path
from typing import Optional

import pandas as pd
from gretel_client import configure_session
Expand Down

0 comments on commit c785ecb

Please sign in to comment.