Benchmark v2 #146

Merged · 2 commits · Jul 31, 2023
10 changes: 10 additions & 0 deletions Makefile
@@ -15,3 +15,13 @@ multi:
 multilint:
 	python -m isort src/gretel_trainer/relational tests/relational
 	python -m black src/gretel_trainer/relational tests/relational
+
+.PHONY: bench
+bench:
+	python -m pyright src/gretel_trainer/benchmark tests/benchmark
+	python -m pytest tests/benchmark/
+
+.PHONY: benchlint
+benchlint:
+	python -m isort src/gretel_trainer/benchmark tests/benchmark
+	python -m black src/gretel_trainer/benchmark tests/benchmark
20 changes: 11 additions & 9 deletions notebooks/benchmark.ipynb
@@ -74,7 +74,7 @@
 "metadata": {},
 "source": [
 "\n",
-"When using your own data, indicate the datatype (select between: \"tabular_mixed\", \"tabular_numeric\", \"natural_language\", and \"time_series\"). Learn more in the [Benchmark docs](https://docs.gretel.ai/reference/benchmark#docs-internal-guid-31c7e29f-7fff-7936-54f8-737618a7e7f3).\n",
+"When using your own data, indicate the datatype (select between: \"tabular\", \"natural_language\", and \"time_series\"). Learn more in the [Benchmark docs](https://docs.gretel.ai/reference/benchmark#docs-internal-guid-31c7e29f-7fff-7936-54f8-737618a7e7f3).\n",
 "\n",
 "Running in Google Colab? You can add your files to the Colab file system, and indicate the path like: \"/content/my_files/data.csv\""
 ]
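For concreteness, a minimal sketch of the renamed v2 entrypoint; the file path and dataset name are placeholders, and the `b` alias follows the notebook's import convention:

import gretel_trainer.benchmark as b

# Placeholder path; in v2 the datatype must be one of
# "tabular", "natural_language", or "time_series"
my_data = b.create_dataset("/content/my_files/data.csv", name="my-data", datatype="tabular")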
@@ -85,7 +85,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"# my_data = b.make_dataset([\"/PATH/TO/MY_DATASET.csv\"], datatype=\"INDICATE_DATATYPE\")"
+"# my_data = b.create_dataset(\"/PATH/TO/MY_DATASET.csv\", name=\"my-data\", datatype=\"INDICATE_DATATYPE\")"
 ]
 },
 {
@@ -109,12 +109,14 @@
 "metadata": {},
 "outputs": [],
 "source": [
+"repo = b.GretelDatasetRepo()\n",
+"\n",
 "datasets = []\n",
-"datasets = b.list_gretel_datasets() # selects all Benchmark datasets\n",
+"datasets = repo.list_datasets() # selects all Benchmark datasets\n",
 "\n",
 "# Other sample commands\n",
-"# datasets = b.list_gretel_datasets(datatype=\"time_series\") # select all time-series datasets\n",
-"# datasets = b.list_gretel_datasets(datatype=\"tabular_mixed\", tags=[\"small\", \"marketing\"]) # select all tabular_mixed, size small, and marketing-related datasets\n",
+"# datasets = repo.list_datasets(datatype=\"time_series\") # select all time-series datasets\n",
+"# datasets = repo.list_datasets(datatype=\"tabular\", tags=[\"small\", \"marketing\"]) # select all tabular, size small, and marketing-related datasets\n",
 "\n",
 "# This will show you all the datasets in the Benchmark dataset bucket\n",
 "[dataset.name for dataset in datasets]"
@@ -127,7 +129,7 @@
 "outputs": [],
 "source": [
 "# Benchmark datasets are annotated with tags, so you can select based on your use case\n",
-"b.list_gretel_dataset_tags() "
+"repo.list_tags() "
 ]
 },
 {
@@ -139,8 +141,8 @@
 "# For this demo, we will select two datasets by name:\n",
 "# \"iris.csv\" - a publicly available dataset for predicting the class of the iris plant based on attributes\n",
 "# \"processed_cleveland_heart_disease_uci.csv\" - a publicly available dataset for predicting presence of heart disease\n",
-"iris = b.get_gretel_dataset(\"iris\")\n",
-"heart_disease = b.get_gretel_dataset(\"processed_cleveland_heart_disease_uci\")"
+"iris = repo.get_dataset(\"iris\")\n",
+"heart_disease = repo.get_dataset(\"processed_cleveland_heart_disease_uci\")"
 ]
 },
 {
@@ -266,7 +268,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"comparison = b.compare(datasets= [heart_disease, iris], models=[GretelLSTM, GretelAmplify])"
+"comparison = b.compare(datasets=[heart_disease, iris], models=[GretelLSTM, GretelAmplify])"
 ]
 },
 {
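Taken together, the notebook's v2 flow reduces to a short script. A minimal end-to-end sketch using only names this PR exports (result polling and report export omitted):

import gretel_trainer.benchmark as b
from gretel_trainer.benchmark import GretelAmplify, GretelLSTM

# Fetch two public Benchmark datasets from the Gretel repo
repo = b.GretelDatasetRepo()
iris = repo.get_dataset("iris")
heart_disease = repo.get_dataset("processed_cleveland_heart_disease_uci")

# Train and evaluate both models against both datasets
comparison = b.compare(datasets=[heart_disease, iris], models=[GretelLSTM, GretelAmplify])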
1 change: 1 addition & 0 deletions requirements.txt
@@ -12,4 +12,5 @@ requests~=2.25
 scikit-learn~=1.0
 smart-open[s3]~=5.2
 sqlalchemy~=1.4
+typing-extensions~=4.7
 unflatten==0.1.1
104 changes: 29 additions & 75 deletions src/gretel_trainer/benchmark/__init__.py
@@ -1,83 +1,37 @@
-from concurrent.futures import ThreadPoolExecutor
-from datetime import datetime
-from typing import List, Optional, Type, Union
-
-import gretel_trainer.benchmark.compare as c
-import gretel_trainer.benchmark.custom.datasets
-import pandas as pd
-
-from gretel_trainer.benchmark.core import Dataset, Datatype, ModelFactory
-from gretel_trainer.benchmark.gretel.datasets import GretelDataset, GretelPublicDatasetRepo
+import logging
+
+from gretel_trainer.benchmark.core import BenchmarkConfig, Datatype
+from gretel_trainer.benchmark.custom.datasets import create_dataset, make_dataset
+from gretel_trainer.benchmark.entrypoints import compare, launch
+from gretel_trainer.benchmark.gretel.datasets import GretelDatasetRepo
+from gretel_trainer.benchmark.gretel.datasets_backwards_compatibility import (
+    get_gretel_dataset,
+    list_gretel_dataset_tags,
+    list_gretel_datasets,
+)
 from gretel_trainer.benchmark.gretel.models import (
+    GretelACTGAN,
     GretelAmplify,
     GretelAuto,
-    GretelACTGAN,
     GretelDGAN,
     GretelGPTX,
     GretelLSTM,
     GretelModel,
 )
-from gretel_trainer.benchmark.gretel.sdk import ActualGretelSDK
-from gretel_trainer.benchmark.gretel.trainer import ActualGretelTrainer
-
-BENCHMARK_DIR = "./.benchmark"
-
-repo = GretelPublicDatasetRepo(
-    bucket="gretel-datasets",
-    region="us-west-2",
-    load_dir=f"{BENCHMARK_DIR}/gretel_datasets",
-)
-
-
-def get_gretel_dataset(name: str) -> GretelDataset:
-    return repo.get_dataset(name)
-
-
-def list_gretel_datasets(
-    datatype: Optional[Union[Datatype, str]] = None, tags: Optional[List[str]] = None
-) -> List[GretelDataset]:
-    return repo.list_datasets(datatype, tags)
-
-
-def list_gretel_dataset_tags() -> List[str]:
-    return repo.list_tags()
-
-
-def make_dataset(
-    sources: Union[List[str], List[pd.DataFrame]],
-    *,
-    datatype: Union[Datatype, str],
-    namespace: Optional[str] = None,
-    delimiter: str = ",",
-) -> Dataset:
-    return gretel_trainer.benchmark.custom.datasets.make_dataset(
-        sources,
-        datatype=datatype,
-        namespace=namespace,
-        delimiter=delimiter,
-        local_dir=BENCHMARK_DIR,
-    )
-
-
-def compare(
-    *,
-    datasets: List[Dataset],
-    models: List[Union[ModelFactory, Type[GretelModel]]],
-    auto_clean: bool = True,
-) -> c.Comparison:
-    return c.compare(
-        datasets=datasets,
-        models=models,
-        runtime_config=c.RuntimeConfig(
-            local_dir=BENCHMARK_DIR,
-            project_prefix=f"benchmark-{_timestamp()}",
-            thread_pool=ThreadPoolExecutor(5),
-            wait_secs=10,
-            auto_clean=auto_clean,
-        ),
-        gretel_sdk=ActualGretelSDK,
-        gretel_trainer_factory=ActualGretelTrainer,
-    )
-
-
-def _timestamp() -> str:
-    return datetime.now().strftime("%Y%m%d%H%M%S")
+
+log_format = "%(levelname)s - %(asctime)s - %(message)s"
+time_format = "%Y-%m-%d %H:%M:%S"
+formatter = logging.Formatter(log_format, time_format)
+handler = logging.StreamHandler()
+handler.setFormatter(formatter)
+
+# Clear out any existing root handlers
+# (This prevents duplicate log output in Colab)
+for root_handler in logging.root.handlers:
+    logging.root.removeHandler(root_handler)
+
+# Configure benchmark loggers
+logger = logging.getLogger("gretel_trainer.benchmark")
+logger.handlers.clear()
+logger.addHandler(handler)
+logger.setLevel("INFO")
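Because the package now installs its own handler on the "gretel_trainer.benchmark" logger, downstream code can adjust verbosity with the standard logging API; a small sketch:

import logging

# Quiet Benchmark's INFO-level output in your own script
logging.getLogger("gretel_trainer.benchmark").setLevel(logging.WARNING)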