From 69eef18299505d03d44b0b59d8de35c5e77be2f9 Mon Sep 17 00:00:00 2001 From: Alexander Watson Date: Fri, 3 Jun 2022 10:22:43 -0700 Subject: [PATCH 01/20] trainer module initial check-in --- notebooks/trainer-example.py | 8 ++ setup.py | 2 +- src/{trainer => gretel_trainer}/__init__.py | 0 src/{trainer => gretel_trainer}/runner.py | 117 ++++++++++----- src/{trainer => gretel_trainer}/strategy.py | 18 ++- .../templates/gretel_ctgan.yaml | 12 ++ src/gretel_trainer/templates/gretel_lstm.yaml | 32 +++++ src/gretel_trainer/trainer.py | 136 ++++++++++++++++++ 8 files changed, 279 insertions(+), 46 deletions(-) create mode 100644 notebooks/trainer-example.py rename src/{trainer => gretel_trainer}/__init__.py (100%) rename src/{trainer => gretel_trainer}/runner.py (89%) rename src/{trainer => gretel_trainer}/strategy.py (90%) create mode 100644 src/gretel_trainer/templates/gretel_ctgan.yaml create mode 100644 src/gretel_trainer/templates/gretel_lstm.yaml create mode 100644 src/gretel_trainer/trainer.py diff --git a/notebooks/trainer-example.py b/notebooks/trainer-example.py new file mode 100644 index 00000000..d7114a70 --- /dev/null +++ b/notebooks/trainer-example.py @@ -0,0 +1,8 @@ +from gretel_trainer import trainer + +dataset_path = "https://gretel-public-website.s3-us-west-2.amazonaws.com/datasets/USAdultIncome5k.csv" + +model = trainer.Trainer(model_type="GretelLSTM") +model.train(dataset_path) +print(model.generate()) + diff --git a/setup.py b/setup.py index acd8be69..9f48f5a4 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ local_path = pathlib.Path(__file__).parent install_requires = (local_path / "requirements.txt").read_text().splitlines() -setup(name="trainer", +setup(name="gretel-trainer", version="0.0.1", package_dir={'': 'src'}, install_requires=install_requires, diff --git a/src/trainer/__init__.py b/src/gretel_trainer/__init__.py similarity index 100% rename from src/trainer/__init__.py rename to src/gretel_trainer/__init__.py diff --git a/src/trainer/runner.py b/src/gretel_trainer/runner.py similarity index 89% rename from src/trainer/runner.py rename to src/gretel_trainer/runner.py index 6629b127..8f4720b6 100644 --- a/src/trainer/runner.py +++ b/src/gretel_trainer/runner.py @@ -30,13 +30,15 @@ from copy import deepcopy from dataclasses import dataclass from functools import wraps +import json import logging from pathlib import Path from typing import List, Optional, Union +import smart_open import pandas as pd -from trainer.strategy import Partition, PartitionConstraints, PartitionStrategy +from gretel_trainer.strategy import Partition, PartitionConstraints, PartitionStrategy from gretel_client.projects import Project from gretel_client.projects.jobs import ACTIVE_STATES @@ -98,7 +100,7 @@ def _remote_dataframe_fetcher(payload: RemoteDFPayload) -> RemoteDFPayload: # We need the model object no matter what model = Model(payload.project, model_id=payload.uid) job = model - + # if we are downloading handler data, we reset our job # to the specific handler object if payload.job_type == "run": @@ -109,16 +111,21 @@ def _remote_dataframe_fetcher(payload: RemoteDFPayload) -> RemoteDFPayload: return payload -def _maybe_submit_job(job: Union[Model, RecordHandler]) -> Optional[Union[Model, RecordHandler]]: +def _maybe_submit_job( + job: Union[Model, RecordHandler] +) -> Optional[Union[Model, RecordHandler]]: try: job = job.submit_cloud() except ApiException as err: if "Maximum number of" in str(err): - logger.warning("Rate limiting: Max jobs created, skipping new job for now...") + 
logger.warning( + "Rate limiting: Max jobs created, skipping new job for now..." + ) return None - + return job + class StrategyRunner: _df: pd.DataFrame @@ -186,7 +193,9 @@ def load(self): self._loaded = True @classmethod - def from_completed(cls, project: Project, cache_file: Union[str, Path]) -> StrategyRunner: + def from_completed( + cls, project: Project, cache_file: Union[str, Path] + ) -> StrategyRunner: cache_file = Path(cache_file) if not cache_file.exists(): raise ValueError("cache file does not exist") @@ -235,6 +244,13 @@ def _update_job_status(self): if current_model.status == Status.COMPLETED: report = current_model.peek_report() + + if report is None: + with smart_open.open( + current_model.get_artifact_link("report_json") + ) as fin: + report = json.loads(fin.read()) + sqs = report["synthetic_data_quality_score"]["score"] label = "Moderate" if sqs >= 80: @@ -248,7 +264,7 @@ def _update_job_status(self): ) _update.update({SQS: report}) - + partition.update_ctx(_update) self._strategy.status_counter = dict(self._status_counter) # Aggressive, but save after every update @@ -277,7 +293,8 @@ def _update_handler_status(self): # Hydrate a Model and Handler object from the remote API current_model = Model(self._project, model_id=model_id) - current_handler = RecordHandler(model=current_model, record_id=handler_id) + current_handler = RecordHandler( + model=current_model, record_id=handler_id) self._handler_status_counter.update([current_handler.status]) @@ -331,7 +348,8 @@ def _remove_unused_artifact(self) -> Optional[str]: continue if artifact_key: - logger.debug(f"Attempting to remove artifact: {p.ctx.get(ARTIFACT)}") + logger.debug( + f"Attempting to remove artifact: {p.ctx.get(ARTIFACT)}") self._project.delete_artifact(artifact_key) p.update_ctx({ARTIFACT: None}) self._strategy.save() @@ -370,7 +388,8 @@ def _remove_unused_artifact(self) -> Optional[str]: def _partition_to_artifact(self, partition: Partition) -> Optional[ArtifactResult]: removed_artifact = self._remove_unused_artifact() if not removed_artifact: - logger.debug("Could not make room for next data set, waiting for room...") + logger.debug( + "Could not make room for next data set, waiting for room...") # We couldn't make room so we don't try and upload the next artifact return None @@ -389,27 +408,36 @@ def _df_to_artifact(self, df: pd.DataFrame, filename: str) -> ArtifactResult: return ArtifactResult(id=artifact_id, record_count=len(df)) @_needs_load - def train_partition(self, partition: Partition, artifact: ArtifactResult) -> Optional[str]: + def train_partition( + self, partition: Partition, artifact: ArtifactResult + ) -> Optional[str]: attempt = partition.ctx.get(ATTEMPT, 0) + 1 model_config = deepcopy(self._model_config) - model_config["models"][0]["synthetics"]["generate"] = { - "num_records": artifact.record_count, - "max_invalid": None, - } + data_source = None - # If we're trying this model for a second+ time, we reduce the vocab size to - # utilize the char encoder in order to give a better chance and success - if attempt > 1: - model_config["models"][0]["synthetics"]["params"]["vocab_size"] = 0 - - # If this partition is for the first-N headers and we have known seed headers, we have to - # modify the configuration to account for the seed task. 
- if partition.columns.seed_headers: - model_config["models"][0]["synthetics"]["task"] = { - "type": "seed", - "attrs": {"fields": partition.columns.seed_headers}, + if "synthetics" in model_config["models"][0].keys(): + + model_config["models"][0]["synthetics"]["generate"] = { + "num_records": artifact.record_count, + "max_invalid": None, } + # If we're trying this model for a second+ time, we reduce the vocab size to + # utilize the char encoder in order to give a better chance and success + if attempt > 1: + model_config["models"][0]["synthetics"]["params"]["vocab_size"] = 0 + + # If this partition is for the first-N headers and we have known seed headers, we have to + # modify the configuration to account for the seed task. + if partition.columns.seed_headers: + model_config["models"][0]["synthetics"]["task"] = { + "type": "seed", + "attrs": {"fields": partition.columns.seed_headers}, + } + + elif "ctgan" in model_config["models"][0].keys(): + pass + model = self._project.create_model_obj( model_config=model_config, data_source=artifact.id ) @@ -428,13 +456,14 @@ def train_partition(self, partition: Partition, artifact: ArtifactResult) -> Opt ) self._strategy.save() logger.info( - f"Started model: {model.print_obj['model_name']} " - f"source: {model.print_obj['config']['models'][0]['synthetics']['data_source']}" + f"Started model: {model.print_obj['model_name']} " f"source: {artifact.id}" ) return model.model_id @_needs_load - def run_partition(self, partition: Partition, gen_payload: GenPayload) -> Optional[str]: + def run_partition( + self, partition: Partition, gen_payload: GenPayload + ) -> Optional[str]: """ Run a record handler for a model and return the job id. @@ -456,7 +485,7 @@ def run_partition(self, partition: Partition, gen_payload: GenPayload) -> Option params={ "num_records": gen_payload.num_records, "max_invalid": gen_payload.max_invalid, - } + }, ) handler_obj = _maybe_submit_job(handler_obj) if handler_obj is None: @@ -517,7 +546,10 @@ def run_next_partition(self, gen_payload: GenPayload) -> Optional[str]: logger.info(f"Generating data for partition {partition.idx}") start_job = True - elif status in (Status.ERROR, Status.LOST) and attempt_count < self._error_retry_limit: + elif ( + status in (Status.ERROR, Status.LOST) + and attempt_count < self._error_retry_limit + ): logger.info( f"Partition {partition.idx} has status {status.value}, re-attempting generation" ) @@ -531,13 +563,17 @@ def run_next_partition(self, gen_payload: GenPayload) -> Optional[str]: gen_payload.seed_df, pd.DataFrame ): # NOTE(jm): If we've tried N-1 attempts with seeds and the handler has continued - # fail then we should not use seeds to at least let the handler try to succeed. + # fail then we should not use seeds to at least let the handler try to succeed. # One example of this happening would be when a partition's model receives seeds # where the values of the seeds were not in the training set (due to partitioning). if attempt_count == self._error_retry_limit - 1: - logger.info(f"WARNING: Disabling seeds for partition {partition.idx} due to previous failed generation attempts...") + logger.info( + f"WARNING: Disabling seeds for partition {partition.idx} due to previous failed generation attempts..." + ) else: - logger.info("Partition has seed fields, uploading seed artifact...") + logger.info( + "Partition has seed fields, uploading seed artifact..." 
+ ) use_seeds = True removed_artifact = self._remove_unused_artifact() if removed_artifact is None: @@ -547,7 +583,8 @@ def run_next_partition(self, gen_payload: GenPayload) -> Optional[str]: return None filename = f"{self.strategy_id}-seeds-{partition.idx}.csv" - artifact = self._df_to_artifact(gen_payload.seed_df, filename) + artifact = self._df_to_artifact( + gen_payload.seed_df, filename) new_payload = GenPayload( num_records=gen_payload.num_records, @@ -628,7 +665,8 @@ def _get_synthetic_data(self, job_type: str, artifact_type: str) -> pd.DataFrame num_completed = self._status_counter.get(Status.COMPLETED, 0) elif job_type == "run": self._update_handler_status() - num_completed = self._handler_status_counter.get(Status.COMPLETED, 0) + num_completed = self._handler_status_counter.get( + Status.COMPLETED, 0) else: raise ValueError("invalid job_type") @@ -697,10 +735,11 @@ def generate_data( seed_df: Optional[pd.DataFrame] = None, num_records: Optional[int] = None, max_invalid: Optional[int] = None, - clear_cache: bool = False + clear_cache: bool = False, ): if seed_df is None and not num_records: - raise ValueError("must provide a seed_df or num_records to generate") + raise ValueError( + "must provide a seed_df or num_records to generate") if isinstance(seed_df, pd.DataFrame) and num_records: raise ValueError("must use one of seed_df or num_records only") @@ -769,7 +808,7 @@ def generate_data( if handler_id is not None: # Go around and try again if this succeeded continue - + # Catch a 4xx in the event we are at capacity, or something else goes wrong except Exception as err: logger.warning(f"Error running model: {str(err)}") diff --git a/src/trainer/strategy.py b/src/gretel_trainer/strategy.py similarity index 90% rename from src/trainer/strategy.py rename to src/gretel_trainer/strategy.py index 72ba50b8..cf7fc1e4 100644 --- a/src/trainer/strategy.py +++ b/src/gretel_trainer/strategy.py @@ -32,7 +32,7 @@ def extract_df(self, df: pd.DataFrame) -> pd.DataFrame: if self.columns is not None: df = df[self.columns.headers] - return df.iloc[self.rows.start : self.rows.end] # noqa + return df.iloc[self.rows.start: self.rows.end] # noqa def update_ctx(self, update: dict): self.ctx.update(update) @@ -47,10 +47,12 @@ class PartitionConstraints: def __post_init__(self): if self.max_row_count is not None and self.max_row_partitions is not None: - raise AttributeError("cannot use both max_row_count and max_row_partitions") + raise AttributeError( + "cannot use both max_row_count and max_row_partitions") if self.max_row_count is None and self.max_row_partitions is None: - raise AttributeError("must use one of max_row_count or max_row_partitions") + raise AttributeError( + "must use one of max_row_count or max_row_partitions") @property def header_cluster_count(self) -> int: @@ -82,7 +84,9 @@ def _build_partitions( partitions.append( Partition( rows=RowPartition(start=next_start, end=next_end), - columns=ColumnPartition(headers=header_cluster, idx=idx, seed_headers=seed_headers), + columns=ColumnPartition( + headers=header_cluster, idx=idx, seed_headers=seed_headers + ), idx=partition_idx, ) ) @@ -114,7 +118,9 @@ def _build_partitions( rows=RowPartition( start=curr_start, end=curr_start + chunk_size ), - columns=ColumnPartition(headers=header_cluster, idx=idx, seed_headers=seed_headers), + columns=ColumnPartition( + headers=header_cluster, idx=idx, seed_headers=seed_headers + ), idx=partition_idx, ) ) @@ -141,7 +147,7 @@ def from_dataframe( id=id, partitions=partitions, 
header_cluster_count=constraints.header_cluster_count, - original_headers=list(df) + original_headers=list(df), ) @classmethod diff --git a/src/gretel_trainer/templates/gretel_ctgan.yaml b/src/gretel_trainer/templates/gretel_ctgan.yaml new file mode 100644 index 00000000..db9eaeb1 --- /dev/null +++ b/src/gretel_trainer/templates/gretel_ctgan.yaml @@ -0,0 +1,12 @@ +schema_version: '1.0' + +models: +- ctgan: + data_source: __tmp__ + params: + batch_size: 200 + discriminator_dim: !!python/tuple [256, 256] + discriminator_lr: 0.00033 + epochs: 50 + generator_dim: !!python/tuple [256, 256] + generator_lr: 1.0e-05 diff --git a/src/gretel_trainer/templates/gretel_lstm.yaml b/src/gretel_trainer/templates/gretel_lstm.yaml new file mode 100644 index 00000000..9c8efcb3 --- /dev/null +++ b/src/gretel_trainer/templates/gretel_lstm.yaml @@ -0,0 +1,32 @@ +schema_version: "1.0" + +models: + - synthetics: + data_source: __tmp__ + params: + epochs: 100 + batch_size: 64 + vocab_size: 20000 + reset_states: False + learning_rate: 0.01 + rnn_units: 256 + dropout_rate: 0.2 + overwrite: True + early_stopping: True + gen_temp: 1.0 + predict_batch_size: 64 + validation_split: False + dp: False + dp_noise_multiplier: 0.001 + dp_l2_norm_clip: 5.0 + dp_microbatches: 1 + data_upsample_limit: 10000 + validators: + in_set_count: 10 + pattern_count: 10 + generate: + num_records: 5000 + max_invalid: null + privacy_filters: + outliers: medium + similarity: medium diff --git a/src/gretel_trainer/trainer.py b/src/gretel_trainer/trainer.py new file mode 100644 index 00000000..2a1e27d1 --- /dev/null +++ b/src/gretel_trainer/trainer.py @@ -0,0 +1,136 @@ +"""Main Trainer Module""" + +from collections import namedtuple +from enum import Enum +import logging +import pandas as pd +import pkgutil +import yaml + +from gretel_client import configure_session, ClientConfig +from gretel_client.projects import create_or_get_unique_project +from gretel_client.projects.models import read_model_config +from gretel_client.projects.jobs import Status +from gretel_synthetics.utils.header_clusters import cluster + +from gretel_trainer import strategy, runner + + +logger = logging.getLogger(__name__) +logger.setLevel("DEBUG") + + +class ExtendedEnum(Enum): + """Utility class for Model enum""" + + @classmethod + def get_types(cls): + return list(map(lambda c: c.name, cls)) + + @classmethod + def get_config(cls, model: str): + return cls[model].config + + +class Model(namedtuple("Model", "config"), ExtendedEnum): + """Enum to pair valid models and configurations""" + + GretelLSTM = "templates/gretel_lstm.yaml" + GretelCTGAN = "templates/gretel_ctgan.yaml" + + +class Trainer: + f"""Automated generative training and sampling tool + + Args: + project_name (str, optional): Gretel project name. Defaults to "trainer". + max_header_clusters (int, optional): Max number of clusters per batch. Defaults to 20. + max_rows (int, optional): Max number of rows per batch. Defaults to 50000. + model_type (str, optional): Options include {Model.get_types()}. Defaults to "GretelLSTM". 
+ """ + + def __init__( + self, + project_name: str = "trainer", + max_header_clusters: int = 20, + max_rows: int = 50000, + model_type: str = "GretelLSTM", + ): + + configure_session(api_key="prompt", cache="yes", validate=True) + + self.df = None + self.dataset_path = None + self.project_name = project_name + self.project = create_or_get_unique_project(name=project_name) + self.max_header_clusters = max_header_clusters + self.max_rows = max_rows + + if model_type in Model.get_types(): + config = pkgutil.get_data(__name__, Model.get_config(model_type)) + self.config = yaml.load(config, Loader=yaml.FullLoader) + logger.debug(self.config) + else: + raise ValueError( + f"Invalid model type. Must be {Model.get_model_types()}") + + def train(self, dataset_path: str, overwrite: bool = True, round_decimals: int = 4): + """Train a model on the dataset + + Args: + dataset_path (str): Path or URL to CSV + overwrite (bool, optional): Overwrite previous progress. Defaults to True. + round_decimals (int, optional): Round decimals in CSV as preprocessing step. Defaults to 4. + """ + self.dataset_path = dataset_path + self.df = self._preprocess_data( + dataset_path=dataset_path, round_decimals=round_decimals + ) + self.run = self._initialize_run(df=self.df, overwrite=overwrite) + self.run.train_all_partitions() + + def load(self): + """Load an existing strategy""" + self.run = self._initialize_run(overwrite=False) + + def generate(self, num_records: int = 500) -> pd.DataFrame: + """Generate synthetic data""" + self.run.generate_data(num_records=num_records, max_invalid=None) + return self.run.get_synthetic_data() + + def _preprocess_data( + self, dataset_path: str, round_decimals: int = 4 + ) -> pd.DataFrame: + """Preprocess input data""" + tmp = pd.read_csv(dataset_path, low_memory=False) + tmp = tmp.round(round_decimals) + return tmp + + def _initialize_run( + self, df: pd.DataFrame = pd.DataFrame(), overwrite: bool = True + ) -> runner.StrategyRunner: + """Create training jobs""" + constraints = None + + if not df.empty: + header_clusters = cluster( + df, maxsize=self.max_header_clusters, plot=False) + logger.info( + f"Header clustering created {len(header_clusters)} cluster(s) " + f"of length(s) {[len(x) for x in header_clusters]}" + ) + + constraints = strategy.PartitionConstraints( + header_clusters=header_clusters, max_row_count=self.max_rows + ) + + run = runner.StrategyRunner( + strategy_id=f"{self.project_name}", + df=self.df, + cache_file=f"{self.project_name}-runner.json", + cache_overwrite=overwrite, + model_config=self.config, + partition_constraints=constraints, + project=self.project, + ) + return run From f4f0591bd04ad55f3f44e69b6baef6f8965fface Mon Sep 17 00:00:00 2001 From: Alexander Watson Date: Fri, 3 Jun 2022 11:40:51 -0700 Subject: [PATCH 02/20] customize training params --- notebooks/trainer-example.py | 8 -------- notebooks/trainer-examples.py | 30 ++++++++++++++++++++++++++++++ src/gretel_trainer/trainer.py | 23 +++++++++++++++++++++-- 3 files changed, 51 insertions(+), 10 deletions(-) delete mode 100644 notebooks/trainer-example.py create mode 100644 notebooks/trainer-examples.py diff --git a/notebooks/trainer-example.py b/notebooks/trainer-example.py deleted file mode 100644 index d7114a70..00000000 --- a/notebooks/trainer-example.py +++ /dev/null @@ -1,8 +0,0 @@ -from gretel_trainer import trainer - -dataset_path = "https://gretel-public-website.s3-us-west-2.amazonaws.com/datasets/USAdultIncome5k.csv" - -model = trainer.Trainer(model_type="GretelLSTM") 
-model.train(dataset_path) -print(model.generate()) - diff --git a/notebooks/trainer-examples.py b/notebooks/trainer-examples.py new file mode 100644 index 00000000..064e2eb7 --- /dev/null +++ b/notebooks/trainer-examples.py @@ -0,0 +1,30 @@ +from gretel_trainer import trainer + +dataset = "https://gretel-public-website.s3-us-west-2.amazonaws.com/datasets/USAdultIncome5k.csv" + +# Simplest example +model = trainer.Trainer() +model.train(dataset) +df = model.generate() + +# Specify underlying model +#model = trainer.Trainer(model_type="GretelLSTM") +#model.train(dataset) +#df = model.generate() + +# Update trainer parameters +#model = trainer.Trainer(max_header_clusters=20, max_rows=50000) +#model.train(dataset) +#df = model.generate() + +# Specify synthetic model and update config params +#model = trainer.Trainer(model_type="GretelCTGAN", model_kwargs={'epochs': 20}) +#model.train(dataset) +#df = model.generate() + +# Load and generate data from an existing model +#model = trainer.Trainer() +#model.load() +#df = model.generate(num_records=500) + +print(df) \ No newline at end of file diff --git a/src/gretel_trainer/trainer.py b/src/gretel_trainer/trainer.py index 2a1e27d1..c0929f1a 100644 --- a/src/gretel_trainer/trainer.py +++ b/src/gretel_trainer/trainer.py @@ -40,13 +40,14 @@ class Model(namedtuple("Model", "config"), ExtendedEnum): class Trainer: - f"""Automated generative training and sampling tool + """Automated generative training and sampling tool Args: project_name (str, optional): Gretel project name. Defaults to "trainer". max_header_clusters (int, optional): Max number of clusters per batch. Defaults to 20. max_rows (int, optional): Max number of rows per batch. Defaults to 50000. - model_type (str, optional): Options include {Model.get_types()}. Defaults to "GretelLSTM". + model_type (str, optional): Options include ["GretelLSTM", "GretelCTGAN"]. Defaults to "GretelLSTM". + model_kwargs (dict, optional): Modify model configuration settings by key. E.g. {'epochs': 20} """ def __init__( @@ -55,6 +56,7 @@ def __init__( max_header_clusters: int = 20, max_rows: int = 50000, model_type: str = "GretelLSTM", + model_kwargs: dict = {} ): configure_session(api_key="prompt", cache="yes", validate=True) @@ -69,7 +71,12 @@ def __init__( if model_type in Model.get_types(): config = pkgutil.get_data(__name__, Model.get_config(model_type)) self.config = yaml.load(config, Loader=yaml.FullLoader) + + # Update default config settings with kwargs by key + for key, value in model_kwargs.items(): + self.config = self.replace_nested_key(self.config, key, value) logger.debug(self.config) + else: raise ValueError( f"Invalid model type. 
Must be {Model.get_model_types()}") @@ -97,6 +104,18 @@ def generate(self, num_records: int = 500) -> pd.DataFrame: """Generate synthetic data""" self.run.generate_data(num_records=num_records, max_invalid=None) return self.run.get_synthetic_data() + + def replace_nested_key(self, data, key, value): + """Replace nested keys""" + if isinstance(data, dict): + return { + k: value if k == key else self.replace_nested_key(v, key, value) + for k, v in data.items() + } + elif isinstance(data, list): + return [self.replace_nested_key(v, key, value) for v in data] + else: + return data def _preprocess_data( self, dataset_path: str, round_decimals: int = 4 From 25c77c668ac37eb7979660068415176ebb6631c2 Mon Sep 17 00:00:00 2001 From: Alexander Watson Date: Fri, 3 Jun 2022 11:48:15 -0700 Subject: [PATCH 03/20] Create __init__.py --- src/gretel_trainer/templates/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/gretel_trainer/templates/__init__.py diff --git a/src/gretel_trainer/templates/__init__.py b/src/gretel_trainer/templates/__init__.py new file mode 100644 index 00000000..e69de29b From 0af9e217c2902109785cb1ee5b875aa262f52baf Mon Sep 17 00:00:00 2001 From: Alexander Watson Date: Fri, 3 Jun 2022 11:52:34 -0700 Subject: [PATCH 04/20] Update setup.py --- setup.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 9f48f5a4..1be27d5f 100644 --- a/setup.py +++ b/setup.py @@ -8,5 +8,7 @@ version="0.0.1", package_dir={'': 'src'}, install_requires=install_requires, - packages=find_packages("src") + packages=find_packages("src"), + package_data={'': ['*.yaml']}, + include_package_data=True ) From a10cbf34860cb907242835196dec8269d2f42409 Mon Sep 17 00:00:00 2001 From: Alexander Watson Date: Fri, 3 Jun 2022 12:00:44 -0700 Subject: [PATCH 05/20] pretty print config --- src/gretel_trainer/trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/gretel_trainer/trainer.py b/src/gretel_trainer/trainer.py index c0929f1a..3f3c9499 100644 --- a/src/gretel_trainer/trainer.py +++ b/src/gretel_trainer/trainer.py @@ -2,6 +2,7 @@ from collections import namedtuple from enum import Enum +import json import logging import pandas as pd import pkgutil @@ -75,7 +76,7 @@ def __init__( # Update default config settings with kwargs by key for key, value in model_kwargs.items(): self.config = self.replace_nested_key(self.config, key, value) - logger.debug(self.config) + logger.debug(json.dumps(self.config,indent=2)) else: raise ValueError( From 0c2e5b705b51eaa7ed7971752213075bbf5e8b59 Mon Sep 17 00:00:00 2001 From: Alexander Watson Date: Fri, 3 Jun 2022 12:00:52 -0700 Subject: [PATCH 06/20] update defaults --- src/gretel_trainer/templates/gretel_lstm.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gretel_trainer/templates/gretel_lstm.yaml b/src/gretel_trainer/templates/gretel_lstm.yaml index 9c8efcb3..7af581bd 100644 --- a/src/gretel_trainer/templates/gretel_lstm.yaml +++ b/src/gretel_trainer/templates/gretel_lstm.yaml @@ -28,5 +28,5 @@ models: num_records: 5000 max_invalid: null privacy_filters: - outliers: medium - similarity: medium + outliers: null + similarity: null From 36c343915c10b8cb3b821a76253a4aaa1014ab91 Mon Sep 17 00:00:00 2001 From: Alexander Watson Date: Mon, 6 Jun 2022 10:50:09 -0700 Subject: [PATCH 07/20] refactor to model_params --- src/gretel_trainer/trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/gretel_trainer/trainer.py 
b/src/gretel_trainer/trainer.py index 3f3c9499..5fe97372 100644 --- a/src/gretel_trainer/trainer.py +++ b/src/gretel_trainer/trainer.py @@ -48,7 +48,7 @@ class Trainer: max_header_clusters (int, optional): Max number of clusters per batch. Defaults to 20. max_rows (int, optional): Max number of rows per batch. Defaults to 50000. model_type (str, optional): Options include ["GretelLSTM", "GretelCTGAN"]. Defaults to "GretelLSTM". - model_kwargs (dict, optional): Modify model configuration settings by key. E.g. {'epochs': 20} + model_params (dict, optional): Modify model configuration settings by key. E.g. {'epochs': 20} """ def __init__( @@ -57,7 +57,7 @@ def __init__( max_header_clusters: int = 20, max_rows: int = 50000, model_type: str = "GretelLSTM", - model_kwargs: dict = {} + model_params: dict = {} ): configure_session(api_key="prompt", cache="yes", validate=True) @@ -74,7 +74,7 @@ def __init__( self.config = yaml.load(config, Loader=yaml.FullLoader) # Update default config settings with kwargs by key - for key, value in model_kwargs.items(): + for key, value in model_params.items(): self.config = self.replace_nested_key(self.config, key, value) logger.debug(json.dumps(self.config,indent=2)) From 69342715bf60cb2e14bdb420d1b2430c85ccc744 Mon Sep 17 00:00:00 2001 From: Alexander Watson Date: Mon, 6 Jun 2022 10:51:23 -0700 Subject: [PATCH 08/20] sampling -> generation --- src/gretel_trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gretel_trainer/trainer.py b/src/gretel_trainer/trainer.py index 5fe97372..edef6d46 100644 --- a/src/gretel_trainer/trainer.py +++ b/src/gretel_trainer/trainer.py @@ -41,7 +41,7 @@ class Model(namedtuple("Model", "config"), ExtendedEnum): class Trainer: - """Automated generative training and sampling tool + """Automated model training and synthetic data generation tool Args: project_name (str, optional): Gretel project name. Defaults to "trainer". 
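The `model_params` override wired into `Trainer.__init__` by the preceding patches works by recursively walking the loaded YAML config and substituting every occurrence of a matching key. A minimal standalone sketch of that traversal (the config literal below is a trimmed illustration, not a full Gretel template):

```python
# Sketch of the recursive key substitution behind Trainer(model_params={...}).
# The config dict is a trimmed illustration of the LSTM template, not the full file.
def replace_nested_key(data, key, value):
    """Return a copy of `data` with every occurrence of `key` set to `value`."""
    if isinstance(data, dict):
        return {
            k: value if k == key else replace_nested_key(v, key, value)
            for k, v in data.items()
        }
    if isinstance(data, list):
        return [replace_nested_key(v, key, value) for v in data]
    return data

config = {"models": [{"synthetics": {"params": {"epochs": 100, "vocab_size": 20000}}}]}
print(replace_nested_key(config, "epochs", 20))
# {'models': [{'synthetics': {'params': {'epochs': 20, 'vocab_size': 20000}}}]}
```

Because the substitution matches on key name alone, an override such as `{'epochs': 20}` rewrites the key wherever it appears in the template, at any nesting depth.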
From 049ad0a5ea75dbfcc7df947e1c990e1144a49872 Mon Sep 17 00:00:00 2001 From: Alexander Watson Date: Mon, 6 Jun 2022 10:53:55 -0700 Subject: [PATCH 09/20] move blank df creation inside function --- src/gretel_trainer/trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/gretel_trainer/trainer.py b/src/gretel_trainer/trainer.py index edef6d46..5a3d03ee 100644 --- a/src/gretel_trainer/trainer.py +++ b/src/gretel_trainer/trainer.py @@ -127,9 +127,10 @@ def _preprocess_data( return tmp def _initialize_run( - self, df: pd.DataFrame = pd.DataFrame(), overwrite: bool = True + self, df: pd.DataFrame = None, overwrite: bool = True ) -> runner.StrategyRunner: """Create training jobs""" + df = pd.DataFrame() constraints = None if not df.empty: From d25e470905d46b35d68587534c363b2511fd3d05 Mon Sep 17 00:00:00 2001 From: Alexander Watson Date: Mon, 6 Jun 2022 11:27:40 -0700 Subject: [PATCH 10/20] refactor yaml tuples->lists --- src/gretel_trainer/templates/gretel_ctgan.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gretel_trainer/templates/gretel_ctgan.yaml b/src/gretel_trainer/templates/gretel_ctgan.yaml index db9eaeb1..c45c9819 100644 --- a/src/gretel_trainer/templates/gretel_ctgan.yaml +++ b/src/gretel_trainer/templates/gretel_ctgan.yaml @@ -5,8 +5,8 @@ models: data_source: __tmp__ params: batch_size: 200 - discriminator_dim: !!python/tuple [256, 256] + discriminator_dim: [256, 256] discriminator_lr: 0.00033 epochs: 50 - generator_dim: !!python/tuple [256, 256] + generator_dim: [256, 256] generator_lr: 1.0e-05 From 78fd319fa4436b0824371598cc47c2dff8caf9f0 Mon Sep 17 00:00:00 2001 From: Alexander Watson Date: Mon, 6 Jun 2022 11:27:52 -0700 Subject: [PATCH 11/20] rename internal method --- src/gretel_trainer/trainer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/gretel_trainer/trainer.py b/src/gretel_trainer/trainer.py index 5a3d03ee..902f1970 100644 --- a/src/gretel_trainer/trainer.py +++ b/src/gretel_trainer/trainer.py @@ -73,9 +73,9 @@ def __init__( config = pkgutil.get_data(__name__, Model.get_config(model_type)) self.config = yaml.load(config, Loader=yaml.FullLoader) - # Update default config settings with kwargs by key + # Update default config settings with params by key for key, value in model_params.items(): - self.config = self.replace_nested_key(self.config, key, value) + self.config = self._replace_nested_key(self.config, key, value) logger.debug(json.dumps(self.config,indent=2)) else: @@ -106,15 +106,15 @@ def generate(self, num_records: int = 500) -> pd.DataFrame: self.run.generate_data(num_records=num_records, max_invalid=None) return self.run.get_synthetic_data() - def replace_nested_key(self, data, key, value): + def _replace_nested_key(self, data, key, value): """Replace nested keys""" if isinstance(data, dict): return { - k: value if k == key else self.replace_nested_key(v, key, value) + k: value if k == key else self._replace_nested_key(v, key, value) for k, v in data.items() } elif isinstance(data, list): - return [self.replace_nested_key(v, key, value) for v in data] + return [self._replace_nested_key(v, key, value) for v in data] else: return data From eef1bce2a50988a0ef59776881fb922b8451682f Mon Sep 17 00:00:00 2001 From: Alexander Watson Date: Mon, 6 Jun 2022 11:48:53 -0700 Subject: [PATCH 12/20] initialize blank DF --- src/gretel_trainer/trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/gretel_trainer/trainer.py b/src/gretel_trainer/trainer.py 
index 902f1970..2ae861ac 100644 --- a/src/gretel_trainer/trainer.py +++ b/src/gretel_trainer/trainer.py @@ -130,7 +130,8 @@ def _initialize_run( self, df: pd.DataFrame = None, overwrite: bool = True ) -> runner.StrategyRunner: """Create training jobs""" - df = pd.DataFrame() + if df is None: + df = pd.DataFrame() constraints = None if not df.empty: From d7dc898d6fcefcdb44bce35cdab54c11ad098574 Mon Sep 17 00:00:00 2001 From: Alexander Watson Date: Mon, 6 Jun 2022 12:10:19 -0700 Subject: [PATCH 13/20] download blueprint configs --- src/gretel_trainer/trainer.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/gretel_trainer/trainer.py b/src/gretel_trainer/trainer.py index 2ae861ac..7b872c70 100644 --- a/src/gretel_trainer/trainer.py +++ b/src/gretel_trainer/trainer.py @@ -5,8 +5,6 @@ import json import logging import pandas as pd -import pkgutil -import yaml from gretel_client import configure_session, ClientConfig from gretel_client.projects import create_or_get_unique_project @@ -36,8 +34,8 @@ def get_config(cls, model: str): class Model(namedtuple("Model", "config"), ExtendedEnum): """Enum to pair valid models and configurations""" - GretelLSTM = "templates/gretel_lstm.yaml" - GretelCTGAN = "templates/gretel_ctgan.yaml" + GretelLSTM = "synthetics/default" + GretelCTGAN = "synthetics/ctgan" class Trainer: @@ -70,8 +68,7 @@ def __init__( self.max_rows = max_rows if model_type in Model.get_types(): - config = pkgutil.get_data(__name__, Model.get_config(model_type)) - self.config = yaml.load(config, Loader=yaml.FullLoader) + self.config = read_model_config(Model.get_config(model_type)) # Update default config settings with params by key for key, value in model_params.items(): From e0f57478a0a0d87fe01db4fefa9ca052af980373 Mon Sep 17 00:00:00 2001 From: Alexander Watson Date: Tue, 7 Jun 2022 10:41:10 -0700 Subject: [PATCH 14/20] organize imports --- src/gretel_trainer/runner.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/gretel_trainer/runner.py b/src/gretel_trainer/runner.py index 8f4720b6..998e8506 100644 --- a/src/gretel_trainer/runner.py +++ b/src/gretel_trainer/runner.py @@ -22,30 +22,29 @@ """ from __future__ import annotations +import json +import logging import tempfile import time - from collections import Counter from concurrent.futures import ALL_COMPLETED, ThreadPoolExecutor, wait from copy import deepcopy from dataclasses import dataclass from functools import wraps -import json -import logging from pathlib import Path from typing import List, Optional, Union -import smart_open import pandas as pd - -from gretel_trainer.strategy import Partition, PartitionConstraints, PartitionStrategy - +import smart_open from gretel_client.projects import Project from gretel_client.projects.jobs import ACTIVE_STATES from gretel_client.projects.models import Model, Status from gretel_client.projects.records import RecordHandler -from gretel_client.users.users import get_me from gretel_client.rest import ApiException +from gretel_client.users.users import get_me + +from gretel_trainer.strategy import (Partition, PartitionConstraints, + PartitionStrategy) MODEL_ID = "model_id" HANDLER_ID = "handler_id" @@ -737,6 +736,7 @@ def generate_data( max_invalid: Optional[int] = None, clear_cache: bool = False, ): + if seed_df is None and not num_records: raise ValueError( "must provide a seed_df or num_records to generate") From 9fc441a9532f146f483ea1247c13b6c26ca59404 Mon Sep 17 00:00:00 2001 From: Alexander Watson Date: Tue, 7 Jun 
2022 10:41:34 -0700 Subject: [PATCH 15/20] load() -> classmethod --- notebooks/trainer-examples.py | 11 ++-- src/gretel_trainer/trainer.py | 103 ++++++++++++++++++++++++++-------- 2 files changed, 84 insertions(+), 30 deletions(-) diff --git a/notebooks/trainer-examples.py b/notebooks/trainer-examples.py index 064e2eb7..11420974 100644 --- a/notebooks/trainer-examples.py +++ b/notebooks/trainer-examples.py @@ -1,4 +1,4 @@ -from gretel_trainer import trainer +from gretel_trainer import trainer, runner dataset = "https://gretel-public-website.s3-us-west-2.amazonaws.com/datasets/USAdultIncome5k.csv" @@ -18,13 +18,12 @@ #df = model.generate() # Specify synthetic model and update config params -#model = trainer.Trainer(model_type="GretelCTGAN", model_kwargs={'epochs': 20}) +#model = trainer.Trainer(model_type="GretelCTGAN", model_params={'epochs':2}) #model.train(dataset) #df = model.generate() # Load and generate data from an existing model -#model = trainer.Trainer() -#model.load() -#df = model.generate(num_records=500) +#model = trainer.Trainer.load() +#df = model.generate(num_records=70) -print(df) \ No newline at end of file +print(df) diff --git a/src/gretel_trainer/trainer.py b/src/gretel_trainer/trainer.py index 7b872c70..795b35ac 100644 --- a/src/gretel_trainer/trainer.py +++ b/src/gretel_trainer/trainer.py @@ -1,23 +1,26 @@ """Main Trainer Module""" -from collections import namedtuple -from enum import Enum import json import logging -import pandas as pd +import os.path +from collections import namedtuple +from enum import Enum -from gretel_client import configure_session, ClientConfig +import pandas as pd +from gretel_client import ClientConfig, configure_session from gretel_client.projects import create_or_get_unique_project -from gretel_client.projects.models import read_model_config from gretel_client.projects.jobs import Status +from gretel_client.projects.models import read_model_config from gretel_synthetics.utils.header_clusters import cluster -from gretel_trainer import strategy, runner - +from gretel_trainer import runner, strategy logger = logging.getLogger(__name__) logger.setLevel("DEBUG") +DEFAULT_PROJECT = "trainer" +DEFAULT_CACHE = f"{DEFAULT_PROJECT}-runner.json" + class ExtendedEnum(Enum): """Utility class for Model enum""" @@ -47,6 +50,8 @@ class Trainer: max_rows (int, optional): Max number of rows per batch. Defaults to 50000. model_type (str, optional): Options include ["GretelLSTM", "GretelCTGAN"]. Defaults to "GretelLSTM". model_params (dict, optional): Modify model configuration settings by key. E.g. {'epochs': 20} + cache_file (str, optional): Select a path to save or load the cache file. Default is `[project_name]-runner.json`. + overwrite (bool, optional): Overwrite previous progress. Defaults to True. 
""" def __init__( @@ -55,17 +60,22 @@ def __init__( max_header_clusters: int = 20, max_rows: int = 50000, model_type: str = "GretelLSTM", - model_params: dict = {} + model_params: dict = {}, + cache_file: str = None, + overwrite: bool = True, ): configure_session(api_key="prompt", cache="yes", validate=True) self.df = None self.dataset_path = None + self.run = None self.project_name = project_name self.project = create_or_get_unique_project(name=project_name) self.max_header_clusters = max_header_clusters self.max_rows = max_rows + self.overwrite = overwrite + self.cache_file = self._get_cache_file(cache_file) if model_type in Model.get_types(): self.config = read_model_config(Model.get_config(model_type)) @@ -73,41 +83,71 @@ def __init__( # Update default config settings with params by key for key, value in model_params.items(): self.config = self._replace_nested_key(self.config, key, value) - logger.debug(json.dumps(self.config,indent=2)) - else: raise ValueError( f"Invalid model type. Must be {Model.get_model_types()}") - def train(self, dataset_path: str, overwrite: bool = True, round_decimals: int = 4): + if self.overwrite: + logger.debug(json.dumps(self.config, indent=2)) + + @classmethod + def load( + cls, cache_file: str = DEFAULT_CACHE, project_name: str = DEFAULT_PROJECT + ) -> runner.StrategyRunner: + """Load an existing project from a cache. + + Args: + cache_file (str, optional): Valid file path to load the cache file from. Defaults to `[project-name]-runner.json` + + Returns: + Trainer: returns an initialized StrategyRunner class. + """ + project = create_or_get_unique_project(name=project_name) + model = cls(cache_file=cache_file, + project_name=project_name, overwrite=False) + + if not os.path.exists(cache_file): + raise ValueError( + f"Unable to find `{cache_file}`. Please specify a valid cache_file." + ) + + model.run = model._initialize_run(df=None, overwrite=model.overwrite) + return model + + def train(self, dataset_path: str, round_decimals: int = 4): """Train a model on the dataset Args: dataset_path (str): Path or URL to CSV - overwrite (bool, optional): Overwrite previous progress. Defaults to True. - round_decimals (int, optional): Round decimals in CSV as preprocessing step. Defaults to 4. + round_decimals (int, optional): Round decimals in CSV as preprocessing step. Defaults to `4`. """ self.dataset_path = dataset_path self.df = self._preprocess_data( dataset_path=dataset_path, round_decimals=round_decimals ) - self.run = self._initialize_run(df=self.df, overwrite=overwrite) + self.run = self._initialize_run(df=self.df, overwrite=self.overwrite) self.run.train_all_partitions() - def load(self): - """Load an existing strategy""" - self.run = self._initialize_run(overwrite=False) - def generate(self, num_records: int = 500) -> pd.DataFrame: - """Generate synthetic data""" - self.run.generate_data(num_records=num_records, max_invalid=None) + """Generate synthetic data + + Args: + num_records (int, optional): Number of records to generate from model. Defaults to 500. + + Returns: + pd.DataFrame: Synthetic data. 
+ """ + self.run.generate_data( + num_records=num_records, max_invalid=None, clear_cache=True + ) return self.run.get_synthetic_data() - + def _replace_nested_key(self, data, key, value): """Replace nested keys""" if isinstance(data, dict): return { - k: value if k == key else self._replace_nested_key(v, key, value) + k: value if k == key else self._replace_nested_key( + v, key, value) for k, v in data.items() } elif isinstance(data, list): @@ -123,13 +163,28 @@ def _preprocess_data( tmp = tmp.round(round_decimals) return tmp + def _get_cache_file(self, cache_file: str) -> str: + """Select a path to store the runtime cache to initialize a model""" + if cache_file is None: + cache_file = f"{self.project_name}-runner.json" + + if os.path.exists(cache_file): + if self.overwrite: + logger.warning( + f"Overwriting existing run cache: {cache_file}.") + else: + logger.info(f"Using existing run cache: {cache_file}.") + else: + logger.info(f"Creating new run cache: {cache_file}.") + return cache_file + def _initialize_run( self, df: pd.DataFrame = None, overwrite: bool = True ) -> runner.StrategyRunner: """Create training jobs""" + constraints = None if df is None: df = pd.DataFrame() - constraints = None if not df.empty: header_clusters = cluster( @@ -146,7 +201,7 @@ def _initialize_run( run = runner.StrategyRunner( strategy_id=f"{self.project_name}", df=self.df, - cache_file=f"{self.project_name}-runner.json", + cache_file=self.cache_file, cache_overwrite=overwrite, model_config=self.config, partition_constraints=constraints, From ebd86c027b3300cff99674442cedc3051942610c Mon Sep 17 00:00:00 2001 From: Alexander Watson Date: Tue, 7 Jun 2022 10:42:52 -0700 Subject: [PATCH 16/20] use configs from blueprint repo --- src/gretel_trainer/templates/__init__.py | 0 .../templates/gretel_ctgan.yaml | 12 ------- src/gretel_trainer/templates/gretel_lstm.yaml | 32 ------------------- 3 files changed, 44 deletions(-) delete mode 100644 src/gretel_trainer/templates/__init__.py delete mode 100644 src/gretel_trainer/templates/gretel_ctgan.yaml delete mode 100644 src/gretel_trainer/templates/gretel_lstm.yaml diff --git a/src/gretel_trainer/templates/__init__.py b/src/gretel_trainer/templates/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/gretel_trainer/templates/gretel_ctgan.yaml b/src/gretel_trainer/templates/gretel_ctgan.yaml deleted file mode 100644 index c45c9819..00000000 --- a/src/gretel_trainer/templates/gretel_ctgan.yaml +++ /dev/null @@ -1,12 +0,0 @@ -schema_version: '1.0' - -models: -- ctgan: - data_source: __tmp__ - params: - batch_size: 200 - discriminator_dim: [256, 256] - discriminator_lr: 0.00033 - epochs: 50 - generator_dim: [256, 256] - generator_lr: 1.0e-05 diff --git a/src/gretel_trainer/templates/gretel_lstm.yaml b/src/gretel_trainer/templates/gretel_lstm.yaml deleted file mode 100644 index 7af581bd..00000000 --- a/src/gretel_trainer/templates/gretel_lstm.yaml +++ /dev/null @@ -1,32 +0,0 @@ -schema_version: "1.0" - -models: - - synthetics: - data_source: __tmp__ - params: - epochs: 100 - batch_size: 64 - vocab_size: 20000 - reset_states: False - learning_rate: 0.01 - rnn_units: 256 - dropout_rate: 0.2 - overwrite: True - early_stopping: True - gen_temp: 1.0 - predict_batch_size: 64 - validation_split: False - dp: False - dp_noise_multiplier: 0.001 - dp_l2_norm_clip: 5.0 - dp_microbatches: 1 - data_upsample_limit: 10000 - validators: - in_set_count: 10 - pattern_count: 10 - generate: - num_records: 5000 - max_invalid: null - privacy_filters: - outliers: null - 
similarity: null From 16da03483f0cceece548eb9175edd165ebb2b473 Mon Sep 17 00:00:00 2001 From: Alexander Watson Date: Tue, 7 Jun 2022 10:47:55 -0700 Subject: [PATCH 17/20] update tests --- tests/test_strategy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_strategy.py b/tests/test_strategy.py index 8440f816..74d2e990 100644 --- a/tests/test_strategy.py +++ b/tests/test_strategy.py @@ -6,7 +6,7 @@ import pytest from gretel_synthetics.utils.header_clusters import cluster -from trainer.strategy import PartitionConstraints, PartitionStrategy +from gretel_trainer.strategy import PartitionConstraints, PartitionStrategy @pytest.fixture(scope="module", autouse=True) From b6e3ff5766d7e384cb018a1f09fd11194fd511c3 Mon Sep 17 00:00:00 2001 From: Alexander Watson Date: Fri, 10 Jun 2022 07:11:05 -0700 Subject: [PATCH 18/20] update docs and setup --- README.md | 38 ++++++++++++++++++++++++++-------- notebooks/gretel-trainer.ipynb | 7 ++++++- setup.py | 14 ++++++++++++- 3 files changed, 48 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 83b3e3ac..6c3e858d 100644 --- a/README.md +++ b/README.md @@ -2,17 +2,37 @@ This code is designed to help users successfully train synthetic models on complex datasets with high row and column counts. The code works by intelligently dividing a dataset into a set of smaller datasets of correlated columns that can be parallelized and then joined together. -# Get Started +# Install -## Running the notebook -1. Launch the [Notebook](https://github.com/gretelai/trainer/blob/main/notebooks/gretel-trainer.ipynb) in [Google Colab](https://colab.research.google.com/github/gretelai/trainer/blob/main/notebooks/gretel-trainer.ipynb) or your preferred environment. -2. Add your dataset and [Gretel API](https://console.gretel.cloud) key to the notebook. -3. Generate synthetic data! +**Using `pip`:** + +```bash +pip install -U gretel-trainer +``` + +# Quickstart + +1. Add your [Gretel API](https://console.gretel.cloud) key via the Gretel CLI. +```bash +gretel configure +``` + +2. Train or fine-tune a model using the Gretel API -**NOTE**: Either delete the existing or choose a new cache file name if you are starting -a dataset run from scratch. +```python3 +from gretel_trainer import trainer + +dataset = "https://gretel-public-website.s3-us-west-2.amazonaws.com/datasets/USAdultIncome5k.csv" + +model = trainer.Trainer() +model.train(dataset) +``` + +3. Generate synthetic data! +```python3 +df = model.generate() +``` # TODOs / Roadmap -- [ ] Enable additional sampling from from trained models. -- [ ] Detect and label encode random UIDs (preprocessing). +- [ ] Enable conditional generation via SDK interface (supported in Notebooks currently). 
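The quickstart above covers a fresh run; the `load()` classmethod introduced earlier in the series instead resumes a completed run from its cache file. A short sketch using the default names (a `trainer` project writes `trainer-runner.json`; adjust both to match your own run):

```python
# Resume a completed run from its cache file and sample additional records.
# The file and project names below are the library defaults, not requirements.
from gretel_trainer import trainer

model = trainer.Trainer.load(
    cache_file="trainer-runner.json",  # default: [project_name]-runner.json
    project_name="trainer",
)
df = model.generate(num_records=500)
print(df.head())
```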
diff --git a/notebooks/gretel-trainer.ipynb b/notebooks/gretel-trainer.ipynb index dd3db750..3f683831 100644 --- a/notebooks/gretel-trainer.ipynb +++ b/notebooks/gretel-trainer.ipynb @@ -307,7 +307,12 @@ "id": "38e44df3" }, "outputs": [], - "source": [] + "source": [ + "# Use the model to generate additional data\n", + "\n", + "run.generate_data(num_records=5000, max_invalid=None, clear_cache=True)\n", + "run.get_synthetic_data()" + ] } ], "metadata": { diff --git a/setup.py b/setup.py index 1be27d5f..2f7c14b5 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,19 @@ version="0.0.1", package_dir={'': 'src'}, install_requires=install_requires, + python_requires=">=3.7", packages=find_packages("src"), package_data={'': ['*.yaml']}, - include_package_data=True + include_package_data=True, + description="Synthetic Data Generation with optional Differential Privacy", + url="https://github.com/gretelai/gretel-trainer", + license="http://www.apache.org/licenses/LICENSE-2.0", + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Operating System :: POSIX :: Linux", + "Operating System :: MacOS", + "Operating System :: Microsoft :: Windows", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + ] ) From c40b2c704af342dd3f737e7a8e3367c481368a19 Mon Sep 17 00:00:00 2001 From: Alexander Watson Date: Fri, 10 Jun 2022 07:34:37 -0700 Subject: [PATCH 19/20] Update README.md --- README.md | 39 ++++++++++++++++++++++++++++++++++----- 1 file changed, 34 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 6c3e858d..6219df6f 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,34 @@ # Gretel Trainer -This code is designed to help users successfully train synthetic models on complex datasets with high row and column counts. The code works by intelligently dividing a dataset into a set of smaller datasets of correlated columns that can be parallelized and then joined together. +This module is designed to provide a simple interface to help users successfully train synthetic models on complex datasets with high row and column counts, and offers features such as Cloud SaaS based training and multi-GPU based parallelization. Get started for free with an API key from [Gretel.ai](https://console.gretel.cloud). + +## Current functionality and features: + +* Synthetic data generators for text, tabular, and time-series data with the following + features: + * Balance datasets or boost a minority class using Conditional Data Generation. + * Automated data validation. + * Synthetic data quality reports. + * Privacy filters and optional differential privacy support. +* Multiple [model types supported](https://docs.gretel.ai/synthetics/models): + * `Gretel-LSTM` model type supports text, tabular, time-series, and conditional data generation. + * `Gretel-CTGAN` model type supports tabular and conditional data generation. + * `Gretel-GPT` natural language synthesis based on an open-source implementation of GPT-3 (coming soon). + * `Gretel-DGAN` multi-variate time series based on DoppelGANger (coming soon). + +## Try it out now! + +If you want to quickly get started synthesizing data with **Gretel.ai**, simply click the button below and follow the examples! 
+
+[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gretelai/gretel-trainer/blob/master/notebooks/trainer-examples.ipynb)
+
+## Join our Slack Workspace
+
+If you want to be part of the Gretel synthetic data community to receive announcements of the latest releases,
+ask questions, suggest new features or participate in the development meetings, please join
+our Slack Workspace!
+
+[![Slack](https://img.shields.io/badge/Slack%20Workspace-Join%20now!-36C5F0?logo=slack)](https://gretel.ai/slackinvite)
 
 # Install
 
@@ -12,12 +40,13 @@ pip install -U gretel-trainer
 
 # Quickstart
 
-1. Add your [Gretel API](https://console.gretel.cloud) key via the Gretel CLI.
+### 1. Add your [Gretel API](https://console.gretel.cloud) key via the Gretel CLI.
+Use the Gretel client to store your API key to disk. This step is optional; the trainer will prompt for an API key in the next step.
 ```bash
 gretel configure
 ```
 
-2. Train or fine-tune a model using the Gretel API
+### 2. Train or fine-tune a model using the Gretel API
 
 ```python3
 from gretel_trainer import trainer
@@ -28,11 +57,11 @@ model = trainer.Trainer()
 model.train(dataset)
 ```
 
-3. Generate synthetic data!
+### 3. Generate synthetic data!
 ```python3
 df = model.generate()
 ```
 
-# TODOs / Roadmap
+## TODOs / Roadmap
 
 - [ ] Enable conditional generation via SDK interface (supported in Notebooks currently).

From 79749303e18c4a621198db462581bf5d142b9433 Mon Sep 17 00:00:00 2001
From: Alexander Watson
Date: Fri, 10 Jun 2022 07:43:33 -0700
Subject: [PATCH 20/20] Update README.md

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 6219df6f..9ca77e7b 100644
--- a/README.md
+++ b/README.md
@@ -18,7 +18,7 @@ This module is designed to provide a simple interface to help users successfully
 
 ## Try it out now!
 
-If you want to quickly get started synthesizing data with **Gretel.ai**, simply click the button below and follow the examples!
+If you want to quickly get started synthesizing data with **Gretel.ai**, simply click the button below and follow the examples. See additional Python3 and Jupyter Notebook examples in the `./notebooks` folder.
 
 [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gretelai/gretel-trainer/blob/master/notebooks/trainer-examples.ipynb)
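Taken together, the series leaves the package with the workflow below; a brief end-to-end sketch (the dataset URL is the public sample used in the repo's examples, and the CTGAN epoch override is illustrative):

```python
# End-to-end usage of the trainer API as it stands at the end of this series.
# model_params overrides matching keys in the blueprint config; omit it to keep defaults.
from gretel_trainer import trainer

dataset = "https://gretel-public-website.s3-us-west-2.amazonaws.com/datasets/USAdultIncome5k.csv"

model = trainer.Trainer(
    model_type="GretelCTGAN",    # or the default "GretelLSTM"
    model_params={"epochs": 2},
)
model.train(dataset)
print(model.generate(num_records=100))
```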