Benchmark v2 (#146)
mikeknep authored Jul 31, 2023
1 parent 7b4d9c6 commit 4f6fbd5
Showing 33 changed files with 2,033 additions and 1,552 deletions.
10 changes: 10 additions & 0 deletions Makefile
@@ -15,3 +15,13 @@ multi:
multilint:
	python -m isort src/gretel_trainer/relational tests/relational
	python -m black src/gretel_trainer/relational tests/relational

.PHONY: bench
bench:
	python -m pyright src/gretel_trainer/benchmark tests/benchmark
	python -m pytest tests/benchmark/

.PHONY: benchlint
benchlint:
	python -m isort src/gretel_trainer/benchmark tests/benchmark
	python -m black src/gretel_trainer/benchmark tests/benchmark
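
For context: with these additions, `make bench` type-checks the benchmark package with pyright and runs its pytest suite, while `make benchlint` applies isort and black to the same paths, mirroring the existing `multi`/`multilint` targets for the relational package.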
20 changes: 11 additions & 9 deletions notebooks/benchmark.ipynb
@@ -74,7 +74,7 @@
"metadata": {},
"source": [
"\n",
"When using your own data, indicate the datatype (select between: \"tabular_mixed\", \"tabular_numeric\", \"natural_language\", and \"time_series\"). Learn more in the [Benchmark docs](https://docs.gretel.ai/reference/benchmark#docs-internal-guid-31c7e29f-7fff-7936-54f8-737618a7e7f3).\n",
"When using your own data, indicate the datatype (select between: \"tabular\", \"natural_language\", and \"time_series\"). Learn more in the [Benchmark docs](https://docs.gretel.ai/reference/benchmark#docs-internal-guid-31c7e29f-7fff-7936-54f8-737618a7e7f3).\n",
"\n",
"Running in Google Colab? You can add your files to the Colab file system, and indicate the path like: \"/content/my_files/data.csv\""
]
@@ -85,7 +85,7 @@
"metadata": {},
"outputs": [],
"source": [
"# my_data = b.make_dataset([\"/PATH/TO/MY_DATASET.csv\"], datatype=\"INDICATE_DATATYPE\")"
"# my_data = b.create_dataset(\"/PATH/TO/MY_DATASET.csv\", name=\"my-data\", datatype=\"INDICATE_DATATYPE\")"
]
},
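
As a concrete sketch of the new single-source API (the path and "tabular" datatype here are placeholders; `create_dataset` replaces the old `make_dataset` call above):

    import gretel_trainer.benchmark as b

    # datatype must be one of "tabular", "natural_language", or "time_series"
    my_data = b.create_dataset(
        "/content/my_files/data.csv",  # e.g. a file uploaded to the Colab file system
        name="my-data",
        datatype="tabular",
    )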
{
@@ -109,12 +109,14 @@
"metadata": {},
"outputs": [],
"source": [
"repo = b.GretelDatasetRepo()\n",
"\n",
"datasets = []\n",
"datasets = b.list_gretel_datasets() # selects all Benchmark datasets\n",
"datasets = repo.list_datasets() # selects all Benchmark datasets\n",
"\n",
"# Other sample commands\n",
"# datasets = b.list_gretel_datasets(datatype=\"time_series\") # select all time-series datasets\n",
"# datasets = b.list_gretel_datasets(datatype=\"tabular_mixed\", tags=[\"small\", \"marketing\"]) # select all tabular_mixed, size small, and marketing-related datasets\n",
"# datasets = repo.list_datasets(datatype=\"time_series\") # select all time-series datasets\n",
"# datasets = repo.list_datasets(datatype=\"tabular\", tags=[\"small\", \"marketing\"]) # select all tabular_mixed, size small, and marketing-related datasets\n",
"\n",
"# This will show you all the datasets in the Benchmark dataset bucket\n",
"[dataset.name for dataset in datasets]"
@@ -127,7 +129,7 @@
"outputs": [],
"source": [
"# Benchmark datasets are annotated with tags, so you can select based on your use case\n",
"b.list_gretel_dataset_tags() "
"repo.list_tags() "
]
},
{
@@ -139,8 +141,8 @@
"# For this demo, we will select two datasets by name:\n",
"# \"iris.csv\" - a publicly available dataset for predicting the class of the iris plant based on attributes\n",
"# \"processed_cleveland_heart_disease_uci.csv\" - a publicly available dataset for predicting presence of heart disease\n",
"iris = b.get_gretel_dataset(\"iris\")\n",
"heart_disease = b.get_gretel_dataset(\"processed_cleveland_heart_disease_uci\")"
"iris = repo.get_dataset(\"iris\")\n",
"heart_disease = repo.get_dataset(\"processed_cleveland_heart_disease_uci\")"
]
},
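
Outside a notebook, the same dataset selection reads as a short script (a sketch assuming only the exports introduced in this commit):

    from gretel_trainer.benchmark import GretelDatasetRepo

    repo = GretelDatasetRepo()
    # Filter the public Benchmark datasets by datatype and/or tags...
    small_marketing = repo.list_datasets(datatype="tabular", tags=["small", "marketing"])
    # ...or fetch specific datasets by name
    iris = repo.get_dataset("iris")
    heart_disease = repo.get_dataset("processed_cleveland_heart_disease_uci")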
{
@@ -266,7 +268,7 @@
"metadata": {},
"outputs": [],
"source": [
"comparison = b.compare(datasets= [heart_disease, iris], models=[GretelLSTM, GretelAmplify])"
"comparison = b.compare(datasets=[heart_disease, iris], models=[GretelLSTM, GretelAmplify])"
]
},
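
The same comparison as a standalone script, assuming the top-level `compare` export shown in the updated `__init__.py` below and the datasets fetched earlier:

    from gretel_trainer.benchmark import GretelAmplify, GretelLSTM, compare

    # Presumably one run per (dataset, model) pair: 2 x 2 = 4 runs here
    comparison = compare(
        datasets=[heart_disease, iris],
        models=[GretelLSTM, GretelAmplify],
    )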
{
1 change: 1 addition & 0 deletions requirements.txt
@@ -12,4 +12,5 @@ requests~=2.25
scikit-learn~=1.0
smart-open[s3]~=5.2
sqlalchemy~=1.4
typing-extensions~=4.7
unflatten==0.1.1
104 changes: 29 additions & 75 deletions src/gretel_trainer/benchmark/__init__.py
@@ -1,83 +1,37 @@
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import List, Optional, Type, Union

import gretel_trainer.benchmark.compare as c
import gretel_trainer.benchmark.custom.datasets
import pandas as pd

from gretel_trainer.benchmark.core import Dataset, Datatype, ModelFactory
from gretel_trainer.benchmark.gretel.datasets import GretelDataset, GretelPublicDatasetRepo
import logging

from gretel_trainer.benchmark.core import BenchmarkConfig, Datatype
from gretel_trainer.benchmark.custom.datasets import create_dataset, make_dataset
from gretel_trainer.benchmark.entrypoints import compare, launch
from gretel_trainer.benchmark.gretel.datasets import GretelDatasetRepo
from gretel_trainer.benchmark.gretel.datasets_backwards_compatibility import (
    get_gretel_dataset,
    list_gretel_dataset_tags,
    list_gretel_datasets,
)
from gretel_trainer.benchmark.gretel.models import (
    GretelACTGAN,
    GretelAmplify,
    GretelAuto,
    GretelACTGAN,
    GretelDGAN,
    GretelGPTX,
    GretelLSTM,
    GretelModel,
)
from gretel_trainer.benchmark.gretel.sdk import ActualGretelSDK
from gretel_trainer.benchmark.gretel.trainer import ActualGretelTrainer

BENCHMARK_DIR = "./.benchmark"

repo = GretelPublicDatasetRepo(
    bucket="gretel-datasets",
    region="us-west-2",
    load_dir=f"{BENCHMARK_DIR}/gretel_datasets",
)


def get_gretel_dataset(name: str) -> GretelDataset:
    return repo.get_dataset(name)


def list_gretel_datasets(
    datatype: Optional[Union[Datatype, str]] = None, tags: Optional[List[str]] = None
) -> List[GretelDataset]:
    return repo.list_datasets(datatype, tags)


def list_gretel_dataset_tags() -> List[str]:
    return repo.list_tags()


def make_dataset(
    sources: Union[List[str], List[pd.DataFrame]],
    *,
    datatype: Union[Datatype, str],
    namespace: Optional[str] = None,
    delimiter: str = ",",
) -> Dataset:
    return gretel_trainer.benchmark.custom.datasets.make_dataset(
        sources,
        datatype=datatype,
        namespace=namespace,
        delimiter=delimiter,
        local_dir=BENCHMARK_DIR,
    )


def compare(
    *,
    datasets: List[Dataset],
    models: List[Union[ModelFactory, Type[GretelModel]]],
    auto_clean: bool = True,
) -> c.Comparison:
    return c.compare(
        datasets=datasets,
        models=models,
        runtime_config=c.RuntimeConfig(
            local_dir=BENCHMARK_DIR,
            project_prefix=f"benchmark-{_timestamp()}",
            thread_pool=ThreadPoolExecutor(5),
            wait_secs=10,
            auto_clean=auto_clean,
        ),
        gretel_sdk=ActualGretelSDK,
        gretel_trainer_factory=ActualGretelTrainer,
    )


def _timestamp() -> str:
    return datetime.now().strftime("%Y%m%d%H%M%S")
log_format = "%(levelname)s - %(asctime)s - %(message)s"
time_format = "%Y-%m-%d %H:%M:%S"
formatter = logging.Formatter(log_format, time_format)
handler = logging.StreamHandler()
handler.setFormatter(formatter)

# Clear out any existing root handlers
# (This prevents duplicate log output in Colab)
for root_handler in logging.root.handlers:
    logging.root.removeHandler(root_handler)

# Configure benchmark loggers
logger = logging.getLogger("gretel_trainer.benchmark")
logger.handlers.clear()
logger.addHandler(handler)
logger.setLevel("INFO")
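
A quick illustration of what this configuration produces once the package is imported (the log message is hypothetical; the format comes from the handler above):

    import logging
    import gretel_trainer.benchmark  # importing the package installs the handler above

    logging.getLogger("gretel_trainer.benchmark").info("Starting comparison")
    # Emits, e.g.: INFO - 2023-07-31 12:00:00 - Starting comparison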