-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #5 from gretelai/aw/trainer-module
DRAFT - Aw/trainer module
- Loading branch information
Showing
9 changed files
with
417 additions
and
65 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,67 @@ | ||
# Gretel Trainer | ||
|
||
This code is designed to help users successfully train synthetic models on complex datasets with high row and column counts. The code works by intelligently dividing a dataset into a set of smaller datasets of correlated columns that can be parallelized and then joined together. | ||
This module is designed to provide a simple interface to help users successfully train synthetic models on complex datasets with high row and column counts, and offers features such as Cloud SaaS based training and multi-GPU based parallelization. Get started for free with an API key from [Gretel.ai](https://console.gretel.cloud). | ||
|
||
# Get Started | ||
## Current functionality and features: | ||
|
||
## Running the notebook | ||
1. Launch the [Notebook](https://github.com/gretelai/trainer/blob/main/notebooks/gretel-trainer.ipynb) in [Google Colab](https://colab.research.google.com/github/gretelai/trainer/blob/main/notebooks/gretel-trainer.ipynb) or your preferred environment. | ||
2. Add your dataset and [Gretel API](https://console.gretel.cloud) key to the notebook. | ||
3. Generate synthetic data! | ||
* Synthetic data generators for text, tabular, and time-series data with the following | ||
features: | ||
* Balance datasets or boost a minority class using Conditional Data Generation. | ||
* Automated data validation. | ||
* Synthetic data quality reports. | ||
* Privacy filters and optional differential privacy support. | ||
* Multiple [model types supported](https://docs.gretel.ai/synthetics/models): | ||
* `Gretel-LSTM` model type supports text, tabular, time-series, and conditional data generation. | ||
* `Gretel-CTGAN` model type supports tabular and conditional data generation. | ||
* `Gretel-GPT` natural language synthesis based on an open-source implementation of GPT-3 (coming soon). | ||
* `Gretel-DGAN` multi-variate time series based on DoppelGANger (coming soon). | ||
|
||
## Try it out now! | ||
|
||
**NOTE**: Either delete the existing or choose a new cache file name if you are starting | ||
a dataset run from scratch. | ||
If you want to quickly get started synthesizing data with **Gretel.ai**, simply click the button below and follow the examples. See additional Python3 and Jupyter Notebook examples in the `./notebooks` folder. | ||
|
||
# TODOs / Roadmap | ||
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gretelai/gretel-trainer/blob/master/notebooks/trainer-examples.ipynb) | ||
|
||
- [ ] Enable additional sampling from from trained models. | ||
- [ ] Detect and label encode random UIDs (preprocessing). | ||
## Join our Slack Workspace | ||
|
||
If you want to be part of the Gretel synthetic data community to receive announcements of the latest releases, | ||
ask questions, suggest new features or participate in the development meetings, please join | ||
our Slack Workspace! | ||
|
||
[![Slack](https://img.shields.io/badge/Slack%20Workspace-Join%20now!-36C5F0?logo=slack)](https://gretel.ai/slackinvite) | ||
|
||
# Install | ||
|
||
**Using `pip`:** | ||
|
||
```bash | ||
pip install -U gretel-trainer | ||
``` | ||
|
||
# Quickstart | ||
|
||
### 1. Add your [Gretel API](https://console.gretel.cloud) key via the Gretel CLI. | ||
Use the Gretel client to store your API key to disk. This step is optional, the trainer will prompt for an API key in the next step. | ||
```bash | ||
gretel configure | ||
``` | ||
|
||
### 2. Train or fine-tune a model using the Gretel API | ||
|
||
```python3 | ||
from gretel_trainer import trainer | ||
|
||
dataset = "https://gretel-public-website.s3-us-west-2.amazonaws.com/datasets/USAdultIncome5k.csv" | ||
|
||
model = trainer.Trainer() | ||
model.train(dataset) | ||
``` | ||
|
||
### 3. Generate synthetic data! | ||
```python3 | ||
df = model.generate() | ||
``` | ||
|
||
## TODOs / Roadmap | ||
|
||
- [ ] Enable conditional generation via SDK interface (supported in Notebooks currently). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from gretel_trainer import trainer, runner | ||
|
||
dataset = "https://gretel-public-website.s3-us-west-2.amazonaws.com/datasets/USAdultIncome5k.csv" | ||
|
||
# Simplest example | ||
model = trainer.Trainer() | ||
model.train(dataset) | ||
df = model.generate() | ||
|
||
# Specify underlying model | ||
#model = trainer.Trainer(model_type="GretelLSTM") | ||
#model.train(dataset) | ||
#df = model.generate() | ||
|
||
# Update trainer parameters | ||
#model = trainer.Trainer(max_header_clusters=20, max_rows=50000) | ||
#model.train(dataset) | ||
#df = model.generate() | ||
|
||
# Specify synthetic model and update config params | ||
#model = trainer.Trainer(model_type="GretelCTGAN", model_params={'epochs':2}) | ||
#model.train(dataset) | ||
#df = model.generate() | ||
|
||
# Load and generate data from an existing model | ||
#model = trainer.Trainer.load() | ||
#df = model.generate(num_records=70) | ||
|
||
print(df) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
Oops, something went wrong.