-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
33 changed files
with
2,033 additions
and
1,552 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,4 +12,5 @@ requests~=2.25 | |
scikit-learn~=1.0 | ||
smart-open[s3]~=5.2 | ||
sqlalchemy~=1.4 | ||
typing-extensions~=4.7 | ||
unflatten==0.1.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,83 +1,37 @@ | ||
from concurrent.futures import ThreadPoolExecutor | ||
from datetime import datetime | ||
from typing import List, Optional, Type, Union | ||
|
||
import gretel_trainer.benchmark.compare as c | ||
import gretel_trainer.benchmark.custom.datasets | ||
import pandas as pd | ||
|
||
from gretel_trainer.benchmark.core import Dataset, Datatype, ModelFactory | ||
from gretel_trainer.benchmark.gretel.datasets import GretelDataset, GretelPublicDatasetRepo | ||
import logging | ||
|
||
from gretel_trainer.benchmark.core import BenchmarkConfig, Datatype | ||
from gretel_trainer.benchmark.custom.datasets import create_dataset, make_dataset | ||
from gretel_trainer.benchmark.entrypoints import compare, launch | ||
from gretel_trainer.benchmark.gretel.datasets import GretelDatasetRepo | ||
from gretel_trainer.benchmark.gretel.datasets_backwards_compatibility import ( | ||
get_gretel_dataset, | ||
list_gretel_dataset_tags, | ||
list_gretel_datasets, | ||
) | ||
from gretel_trainer.benchmark.gretel.models import ( | ||
GretelACTGAN, | ||
GretelAmplify, | ||
GretelAuto, | ||
GretelACTGAN, | ||
GretelDGAN, | ||
GretelGPTX, | ||
GretelLSTM, | ||
GretelModel, | ||
) | ||
from gretel_trainer.benchmark.gretel.sdk import ActualGretelSDK | ||
from gretel_trainer.benchmark.gretel.trainer import ActualGretelTrainer | ||
|
||
BENCHMARK_DIR = "./.benchmark" | ||
|
||
repo = GretelPublicDatasetRepo( | ||
bucket="gretel-datasets", | ||
region="us-west-2", | ||
load_dir=f"{BENCHMARK_DIR}/gretel_datasets", | ||
) | ||
|
||
|
||
def get_gretel_dataset(name: str) -> GretelDataset:
    """Look up a single public Gretel dataset by name in the shared repo."""
    dataset = repo.get_dataset(name)
    return dataset
|
||
|
||
def list_gretel_datasets(
    datatype: Optional[Union[Datatype, str]] = None, tags: Optional[List[str]] = None
) -> List[GretelDataset]:
    """Enumerate public Gretel datasets, optionally filtered by datatype and tags."""
    matches = repo.list_datasets(datatype, tags)
    return matches
|
||
|
||
def list_gretel_dataset_tags() -> List[str]:
    """Return every tag known to the public dataset repo."""
    tags = repo.list_tags()
    return tags
|
||
|
||
def make_dataset(
    sources: Union[List[str], List[pd.DataFrame]],
    *,
    datatype: Union[Datatype, str],
    namespace: Optional[str] = None,
    delimiter: str = ",",
) -> Dataset:
    """Build a benchmark Dataset from files or DataFrames.

    Delegates to the custom-datasets module, pinning the benchmark
    working directory as the local staging dir.
    """
    builder = gretel_trainer.benchmark.custom.datasets.make_dataset
    return builder(
        sources,
        datatype=datatype,
        namespace=namespace,
        delimiter=delimiter,
        local_dir=BENCHMARK_DIR,
    )
|
||
|
||
def compare(
    *,
    datasets: List[Dataset],
    models: List[Union[ModelFactory, Type[GretelModel]]],
    auto_clean: bool = True,
) -> c.Comparison:
    """Run every model against every dataset and return the resulting Comparison.

    A fresh timestamped project prefix is generated per invocation so
    concurrent runs do not collide.
    """
    runtime = c.RuntimeConfig(
        local_dir=BENCHMARK_DIR,
        project_prefix=f"benchmark-{_timestamp()}",
        thread_pool=ThreadPoolExecutor(5),
        wait_secs=10,
        auto_clean=auto_clean,
    )
    return c.compare(
        datasets=datasets,
        models=models,
        runtime_config=runtime,
        gretel_sdk=ActualGretelSDK,
        gretel_trainer_factory=ActualGretelTrainer,
    )
|
||
|
||
def _timestamp() -> str: | ||
return datetime.now().strftime("%Y%m%d%H%M%S") | ||
# Console logging for benchmark output: "LEVEL - timestamp - message".
log_format = "%(levelname)s - %(asctime)s - %(message)s"
time_format = "%Y-%m-%d %H:%M:%S"
formatter = logging.Formatter(log_format, time_format)
handler = logging.StreamHandler()
handler.setFormatter(formatter)

# Clear out any existing root handlers
# (This prevents duplicate log output in Colab)
# BUG FIX: iterate over a *copy* of the handler list. Removing from the
# live list while iterating it skips every other handler, so some root
# handlers were left installed and duplicates could still appear.
for root_handler in list(logging.root.handlers):
    logging.root.removeHandler(root_handler)

# Configure benchmark loggers
logger = logging.getLogger("gretel_trainer.benchmark")
logger.handlers.clear()
logger.addHandler(handler)
logger.setLevel("INFO")
Oops, something went wrong.