Skip to content

Commit

Permalink
Export helper as fixture instead of importing from test package (#147)
Browse files Browse the repository at this point in the history
  • Loading branch information
mikeknep authored Aug 1, 2023
1 parent e5fd451 commit a5671c4
Show file tree
Hide file tree
Showing 6 changed files with 179 additions and 174 deletions.
63 changes: 0 additions & 63 deletions tests/benchmark/mocks.py

This file was deleted.

68 changes: 61 additions & 7 deletions tests/benchmark/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pandas.testing as pdtest
import pytest
from gretel_client.projects.jobs import Status
from gretel_client.projects.models import read_model_config

from gretel_trainer.benchmark import (
BenchmarkConfig,
Expand All @@ -17,13 +18,66 @@
create_dataset,
launch,
)
from tests.benchmark.mocks import (
DoNothingModel,
FailsToGenerate,
FailsToTrain,
SharedDictLstm,
TailoredActgan,
)
from gretel_trainer.benchmark.core import Dataset
from gretel_trainer.benchmark.gretel.models import GretelModel


class DoNothingModel:
def train(self, source: Dataset, **kwargs) -> None:
pass

def generate(self, **kwargs) -> pd.DataFrame:
return pd.DataFrame()


class FailsToTrain:
def train(self, source: Dataset, **kwargs) -> None:
raise Exception("failed")

def generate(self, **kwargs) -> pd.DataFrame:
return pd.DataFrame()


class FailsToGenerate:
def train(self, source: Dataset, **kwargs) -> None:
pass

def generate(self, **kwargs) -> pd.DataFrame:
raise Exception("failed")


class TailoredActgan(GretelModel):
@property
def config(self):
c = read_model_config("synthetics/tabular-actgan")
c["models"][0]["actgan"]["params"]["epochs"] = 100
return c


class SharedDictLstm(GretelModel):
config = {
"schema_version": "1.0",
"name": "tabular-lstm",
"models": [
{
"synthetics": {
"data_source": "__tmp__",
"params": {
"epochs": "auto",
"vocab_size": "auto",
"learning_rate": "auto",
"batch_size": "auto",
"rnn_units": "auto",
},
"generate": {"num_records": 5000},
"privacy_filters": {
"outliers": "auto",
"similarity": "auto",
},
}
}
],
}


def test_run_with_gretel_dataset(working_dir, project, evaluate_report_path, iris):
Expand Down
10 changes: 7 additions & 3 deletions tests/relational/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sqlite3
import tempfile
from pathlib import Path
from typing import Generator
from typing import Callable, Generator
from unittest.mock import Mock, patch

import pandas as pd
Expand Down Expand Up @@ -32,8 +32,12 @@ def static_suffix(request):
yield make_suffix


# Doesn't work well as a fixture due to the need for an input param
def get_invented_table_suffix(make_suffix_execution_number: int):
@pytest.fixture()
def get_invented_table_suffix() -> Callable[[int], str]:
return _get_invented_table_suffix


def _get_invented_table_suffix(make_suffix_execution_number: int):
return f"invented_{str(make_suffix_execution_number)}"


Expand Down
3 changes: 1 addition & 2 deletions tests/relational/test_backup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
BackupSyntheticsTrain,
BackupTransformsTrain,
)
from tests.relational.conftest import get_invented_table_suffix


def test_backup_relational_data(trips):
Expand All @@ -37,7 +36,7 @@ def test_backup_relational_data(trips):
assert BackupRelationalData.from_relational_data(trips) == expected


def test_backup_relational_data_with_json(documents):
def test_backup_relational_data_with_json(documents, get_invented_table_suffix):
purchases_root_invented_table = f"purchases_{get_invented_table_suffix(1)}"
purchases_data_years_invented_table = f"purchases_{get_invented_table_suffix(2)}"

Expand Down
Loading

0 comments on commit a5671c4

Please sign in to comment.