diff --git a/notebooks/trainer-examples.ipynb b/notebooks/trainer-examples.ipynb index 2f8d0830..bfd4d566 100644 --- a/notebooks/trainer-examples.ipynb +++ b/notebooks/trainer-examples.ipynb @@ -109,6 +109,26 @@ "model = trainer.Trainer.load()\n", "model.generate(seed_df=seed_df)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Use Gretel amplify to generate large amount of data (GBs)\n", + "\n", + "from gretel_trainer import trainer\n", + "from gretel_trainer.models import GretelAmplify\n", + "\n", + "dataset = \"https://gretel-public-website.s3-us-west-2.amazonaws.com/datasets/USAdultIncome5k.csv\"\n", + "\n", + "model_type = GretelAmplify()\n", + "\n", + "model = trainer.Trainer(model_type=model_type)\n", + "model.train(dataset)\n", + "model.generate()" + ] } ], "metadata": { diff --git a/src/gretel_trainer/models.py b/src/gretel_trainer/models.py index 5676c60e..7d9d0b4d 100644 --- a/src/gretel_trainer/models.py +++ b/src/gretel_trainer/models.py @@ -10,12 +10,12 @@ HIGH_COLUMN_THRESHOLD = 20 -HIGH_RECORD_THRESHOLD = 50000 +HIGH_RECORD_THRESHOLD = 50_000 LOW_COLUMN_THRESHOLD = 4 -LOW_RECORD_THRESHOLD = 1000 +LOW_RECORD_THRESHOLD = 1_000 -def determine_best_model(df: pd.DataFrame): +def determine_best_model(df: pd.DataFrame) -> _BaseConfig: row_count, column_count = df.shape if row_count > HIGH_RECORD_THRESHOLD or column_count > HIGH_COLUMN_THRESHOLD: @@ -53,12 +53,14 @@ def __init__( self.max_header_clusters = max_header_clusters self.enable_privacy_filters = enable_privacy_filters + self._handle_privacy_filters() + self.validate() + + def _handle_privacy_filters(self): if not self.enable_privacy_filters: logging.warning("Privacy filters disabled. Enable with the `enable_privacy_filters` param.") self.update_params({"outliers": None, "similarity": None}) - self.validate() - def update_params(self, params: dict): """Convenience function to update model specific parameters from the base config by key value. @@ -95,13 +97,13 @@ def _replace_nested_key(self, data, key, value) -> dict: class GretelLSTM(_BaseConfig): _max_header_clusters_limit: int = 30 - _max_rows_limit: int = 5000000 + _max_rows_limit: int = 5_000_000 _model_slug: str = "synthetics" def __init__( self, config="synthetics/default", - max_rows=50000, + max_rows=50_000, max_header_clusters=20, enable_privacy_filters=False, ): @@ -115,14 +117,14 @@ def __init__( class GretelCTGAN(_BaseConfig): - _max_header_clusters_limit: int = 1000 - _max_rows_limit: int = 5000000 + _max_header_clusters_limit: int = 1_000 + _max_rows_limit: int = 5_000_000 _model_slug: str = "ctgan" def __init__( self, config="synthetics/high-dimensionality", - max_rows=50000, + max_rows=50_000, max_header_clusters=500, enable_privacy_filters=False, ): @@ -131,4 +133,28 @@ def __init__( max_rows=max_rows, max_header_clusters=max_header_clusters, enable_privacy_filters=enable_privacy_filters, - ) \ No newline at end of file + ) + + +class GretelAmplify(_BaseConfig): + + _max_header_clusters_limit: int = 1_000 + _max_rows_limit: int = 1_000_000_000 + _model_slug: str = "amplify" + + def __init__( + self, + config="synthetics/amplify", + max_rows=50_000, + max_header_clusters=500, + ): + super().__init__( + config=config, + max_rows=max_rows, + max_header_clusters=max_header_clusters, + enable_privacy_filters=False, + ) + + def _handle_privacy_filters(self) -> None: + # Currently amplify doesn't support privacy filtering + pass diff --git a/src/gretel_trainer/trainer.py b/src/gretel_trainer/trainer.py index eb0106ab..28f3ba0c 100644 --- a/src/gretel_trainer/trainer.py +++ b/src/gretel_trainer/trainer.py @@ -4,6 +4,7 @@ import logging import os.path from pathlib import Path +from typing import Optional import pandas as pd from gretel_client import configure_session