diff --git a/README.md b/README.md index 83b3e3ac..9ca77e7b 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,67 @@ # Gretel Trainer -This code is designed to help users successfully train synthetic models on complex datasets with high row and column counts. The code works by intelligently dividing a dataset into a set of smaller datasets of correlated columns that can be parallelized and then joined together. +This module provides a simple interface to help users successfully train synthetic models on complex datasets with high row and column counts, and offers features such as cloud-based (SaaS) training and multi-GPU parallelization. Get started for free with an API key from [Gretel.ai](https://console.gretel.cloud). -# Get Started +## Current functionality and features: -## Running the notebook -1. Launch the [Notebook](https://github.com/gretelai/trainer/blob/main/notebooks/gretel-trainer.ipynb) in [Google Colab](https://colab.research.google.com/github/gretelai/trainer/blob/main/notebooks/gretel-trainer.ipynb) or your preferred environment. -2. Add your dataset and [Gretel API](https://console.gretel.cloud) key to the notebook. -3. Generate synthetic data! +* Synthetic data generators for text, tabular, and time-series data with the following + features: + * Balance datasets or boost a minority class using Conditional Data Generation. + * Automated data validation. + * Synthetic data quality reports. + * Privacy filters and optional differential privacy support. +* Multiple [model types supported](https://docs.gretel.ai/synthetics/models): + * `Gretel-LSTM` model type supports text, tabular, time-series, and conditional data generation. + * `Gretel-CTGAN` model type supports tabular and conditional data generation. + * `Gretel-GPT` natural language synthesis based on an open-source implementation of GPT-3 (coming soon). + * `Gretel-DGAN` multi-variate time series based on DoppelGANger (coming soon). + +## Try it out now! -If you want to quickly get started synthesizing data with **Gretel.ai**, simply click the button below and follow the examples. See additional Python3 and Jupyter Notebook examples in the `./notebooks` folder. -**NOTE**: Either delete the existing or choose a new cache file name if you are starting -a dataset run from scratch. -# TODOs / Roadmap +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gretelai/gretel-trainer/blob/master/notebooks/trainer-examples.ipynb) -- [ ] Enable additional sampling from from trained models. -- [ ] Detect and label encode random UIDs (preprocessing). +## Join our Slack Workspace + +If you want to be part of the Gretel synthetic data community to receive announcements of the latest releases, +ask questions, suggest new features, or participate in the development meetings, please join +our Slack Workspace! + +[![Slack](https://img.shields.io/badge/Slack%20Workspace-Join%20now!-36C5F0?logo=slack)](https://gretel.ai/slackinvite) + +# Install + +**Using `pip`:** + +```bash +pip install -U gretel-trainer +``` + +# Quickstart + +### 1. Add your [Gretel API](https://console.gretel.cloud) key via the Gretel CLI. +Use the Gretel client to store your API key to disk. This step is optional; if you skip it, the trainer will prompt for an API key in the next step. +```bash +gretel configure +``` + +### 2. 
Train or fine-tune a model using the Gretel API + +```python3 +from gretel_trainer import trainer + +dataset = "https://gretel-public-website.s3-us-west-2.amazonaws.com/datasets/USAdultIncome5k.csv" + +model = trainer.Trainer() +model.train(dataset) +``` + +### 3. Generate synthetic data! +```python3 +df = model.generate() +``` + +## TODOs / Roadmap + +- [ ] Enable conditional generation via SDK interface (supported in Notebooks currently). diff --git a/notebooks/gretel-trainer.ipynb b/notebooks/gretel-trainer.ipynb index dd3db750..3f683831 100644 --- a/notebooks/gretel-trainer.ipynb +++ b/notebooks/gretel-trainer.ipynb @@ -307,7 +307,12 @@ "id": "38e44df3" }, "outputs": [], - "source": [] + "source": [ + "# Use the model to generate additional data\n", + "\n", + "run.generate_data(num_records=5000, max_invalid=None, clear_cache=True)\n", + "run.get_synthetic_data()" + ] } ], "metadata": { diff --git a/notebooks/trainer-examples.py b/notebooks/trainer-examples.py new file mode 100644 index 00000000..11420974 --- /dev/null +++ b/notebooks/trainer-examples.py @@ -0,0 +1,29 @@ +from gretel_trainer import trainer, runner + +dataset = "https://gretel-public-website.s3-us-west-2.amazonaws.com/datasets/USAdultIncome5k.csv" + +# Simplest example +model = trainer.Trainer() +model.train(dataset) +df = model.generate() + +# Specify underlying model +#model = trainer.Trainer(model_type="GretelLSTM") +#model.train(dataset) +#df = model.generate() + +# Update trainer parameters +#model = trainer.Trainer(max_header_clusters=20, max_rows=50000) +#model.train(dataset) +#df = model.generate() + +# Specify synthetic model and update config params +#model = trainer.Trainer(model_type="GretelCTGAN", model_params={'epochs':2}) +#model.train(dataset) +#df = model.generate() + +# Load and generate data from an existing model +#model = trainer.Trainer.load() +#df = model.generate(num_records=70) + +print(df) diff --git a/setup.py b/setup.py index acd8be69..2f7c14b5 100644 --- a/setup.py +++ b/setup.py @@ -4,9 +4,23 @@ local_path = pathlib.Path(__file__).parent install_requires = (local_path / "requirements.txt").read_text().splitlines() -setup(name="trainer", +setup(name="gretel-trainer", version="0.0.1", package_dir={'': 'src'}, install_requires=install_requires, - packages=find_packages("src") + python_requires=">=3.7", + packages=find_packages("src"), + package_data={'': ['*.yaml']}, + include_package_data=True, + description="Synthetic Data Generation with optional Differential Privacy", + url="https://github.com/gretelai/gretel-trainer", + license="http://www.apache.org/licenses/LICENSE-2.0", + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Operating System :: POSIX :: Linux", + "Operating System :: MacOS", + "Operating System :: Microsoft :: Windows", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + ] ) diff --git a/src/trainer/__init__.py b/src/gretel_trainer/__init__.py similarity index 100% rename from src/trainer/__init__.py rename to src/gretel_trainer/__init__.py diff --git a/src/trainer/runner.py b/src/gretel_trainer/runner.py similarity index 89% rename from src/trainer/runner.py rename to src/gretel_trainer/runner.py index 6629b127..998e8506 100644 --- a/src/trainer/runner.py +++ b/src/gretel_trainer/runner.py @@ -22,28 +22,29 @@ """ from __future__ import annotations +import json +import logging import tempfile import time - from collections import Counter from concurrent.futures import ALL_COMPLETED, 
ThreadPoolExecutor, wait from copy import deepcopy from dataclasses import dataclass from functools import wraps -import logging from pathlib import Path from typing import List, Optional, Union import pandas as pd - -from trainer.strategy import Partition, PartitionConstraints, PartitionStrategy - +import smart_open from gretel_client.projects import Project from gretel_client.projects.jobs import ACTIVE_STATES from gretel_client.projects.models import Model, Status from gretel_client.projects.records import RecordHandler -from gretel_client.users.users import get_me from gretel_client.rest import ApiException +from gretel_client.users.users import get_me + +from gretel_trainer.strategy import (Partition, PartitionConstraints, + PartitionStrategy) MODEL_ID = "model_id" HANDLER_ID = "handler_id" @@ -98,7 +99,7 @@ def _remote_dataframe_fetcher(payload: RemoteDFPayload) -> RemoteDFPayload: # We need the model object no matter what model = Model(payload.project, model_id=payload.uid) job = model - + # if we are downloading handler data, we reset our job # to the specific handler object if payload.job_type == "run": @@ -109,16 +110,21 @@ def _remote_dataframe_fetcher(payload: RemoteDFPayload) -> RemoteDFPayload: return payload -def _maybe_submit_job(job: Union[Model, RecordHandler]) -> Optional[Union[Model, RecordHandler]]: +def _maybe_submit_job( + job: Union[Model, RecordHandler] +) -> Optional[Union[Model, RecordHandler]]: try: job = job.submit_cloud() except ApiException as err: if "Maximum number of" in str(err): - logger.warning("Rate limiting: Max jobs created, skipping new job for now...") + logger.warning( + "Rate limiting: Max jobs created, skipping new job for now..." + ) return None - + return job + class StrategyRunner: _df: pd.DataFrame @@ -186,7 +192,9 @@ def load(self): self._loaded = True @classmethod - def from_completed(cls, project: Project, cache_file: Union[str, Path]) -> StrategyRunner: + def from_completed( + cls, project: Project, cache_file: Union[str, Path] + ) -> StrategyRunner: cache_file = Path(cache_file) if not cache_file.exists(): raise ValueError("cache file does not exist") @@ -235,6 +243,13 @@ def _update_job_status(self): if current_model.status == Status.COMPLETED: report = current_model.peek_report() + + if report is None: + with smart_open.open( + current_model.get_artifact_link("report_json") + ) as fin: + report = json.loads(fin.read()) + sqs = report["synthetic_data_quality_score"]["score"] label = "Moderate" if sqs >= 80: @@ -248,7 +263,7 @@ def _update_job_status(self): ) _update.update({SQS: report}) - + partition.update_ctx(_update) self._strategy.status_counter = dict(self._status_counter) # Aggressive, but save after every update @@ -277,7 +292,8 @@ def _update_handler_status(self): # Hydrate a Model and Handler object from the remote API current_model = Model(self._project, model_id=model_id) - current_handler = RecordHandler(model=current_model, record_id=handler_id) + current_handler = RecordHandler( + model=current_model, record_id=handler_id) self._handler_status_counter.update([current_handler.status]) @@ -331,7 +347,8 @@ def _remove_unused_artifact(self) -> Optional[str]: continue if artifact_key: - logger.debug(f"Attempting to remove artifact: {p.ctx.get(ARTIFACT)}") + logger.debug( + f"Attempting to remove artifact: {p.ctx.get(ARTIFACT)}") self._project.delete_artifact(artifact_key) p.update_ctx({ARTIFACT: None}) self._strategy.save() @@ -370,7 +387,8 @@ def _remove_unused_artifact(self) -> Optional[str]: def 
_partition_to_artifact(self, partition: Partition) -> Optional[ArtifactResult]: removed_artifact = self._remove_unused_artifact() if not removed_artifact: - logger.debug("Could not make room for next data set, waiting for room...") + logger.debug( + "Could not make room for next data set, waiting for room...") # We couldn't make room so we don't try and upload the next artifact return None @@ -389,27 +407,36 @@ def _df_to_artifact(self, df: pd.DataFrame, filename: str) -> ArtifactResult: return ArtifactResult(id=artifact_id, record_count=len(df)) @_needs_load - def train_partition(self, partition: Partition, artifact: ArtifactResult) -> Optional[str]: + def train_partition( + self, partition: Partition, artifact: ArtifactResult + ) -> Optional[str]: attempt = partition.ctx.get(ATTEMPT, 0) + 1 model_config = deepcopy(self._model_config) - model_config["models"][0]["synthetics"]["generate"] = { - "num_records": artifact.record_count, - "max_invalid": None, - } + data_source = None + + if "synthetics" in model_config["models"][0].keys(): - # If we're trying this model for a second+ time, we reduce the vocab size to - # utilize the char encoder in order to give a better chance and success - if attempt > 1: - model_config["models"][0]["synthetics"]["params"]["vocab_size"] = 0 - - # If this partition is for the first-N headers and we have known seed headers, we have to - # modify the configuration to account for the seed task. - if partition.columns.seed_headers: - model_config["models"][0]["synthetics"]["task"] = { - "type": "seed", - "attrs": {"fields": partition.columns.seed_headers}, + model_config["models"][0]["synthetics"]["generate"] = { + "num_records": artifact.record_count, + "max_invalid": None, } + # If we're trying this model for a second+ time, we reduce the vocab size to + # utilize the char encoder in order to give a better chance and success + if attempt > 1: + model_config["models"][0]["synthetics"]["params"]["vocab_size"] = 0 + + # If this partition is for the first-N headers and we have known seed headers, we have to + # modify the configuration to account for the seed task. + if partition.columns.seed_headers: + model_config["models"][0]["synthetics"]["task"] = { + "type": "seed", + "attrs": {"fields": partition.columns.seed_headers}, + } + + elif "ctgan" in model_config["models"][0].keys(): + pass + model = self._project.create_model_obj( model_config=model_config, data_source=artifact.id ) @@ -428,13 +455,14 @@ def train_partition(self, partition: Partition, artifact: ArtifactResult) -> Opt ) self._strategy.save() logger.info( - f"Started model: {model.print_obj['model_name']} " - f"source: {model.print_obj['config']['models'][0]['synthetics']['data_source']}" + f"Started model: {model.print_obj['model_name']} " f"source: {artifact.id}" ) return model.model_id @_needs_load - def run_partition(self, partition: Partition, gen_payload: GenPayload) -> Optional[str]: + def run_partition( + self, partition: Partition, gen_payload: GenPayload + ) -> Optional[str]: """ Run a record handler for a model and return the job id. 
@@ -456,7 +484,7 @@ def run_partition(self, partition: Partition, gen_payload: GenPayload) -> Option params={ "num_records": gen_payload.num_records, "max_invalid": gen_payload.max_invalid, - } + }, ) handler_obj = _maybe_submit_job(handler_obj) if handler_obj is None: @@ -517,7 +545,10 @@ def run_next_partition(self, gen_payload: GenPayload) -> Optional[str]: logger.info(f"Generating data for partition {partition.idx}") start_job = True - elif status in (Status.ERROR, Status.LOST) and attempt_count < self._error_retry_limit: + elif ( + status in (Status.ERROR, Status.LOST) + and attempt_count < self._error_retry_limit + ): logger.info( f"Partition {partition.idx} has status {status.value}, re-attempting generation" ) @@ -531,13 +562,17 @@ def run_next_partition(self, gen_payload: GenPayload) -> Optional[str]: gen_payload.seed_df, pd.DataFrame ): # NOTE(jm): If we've tried N-1 attempts with seeds and the handler has continued - # fail then we should not use seeds to at least let the handler try to succeed. + # fail then we should not use seeds to at least let the handler try to succeed. # One example of this happening would be when a partition's model receives seeds # where the values of the seeds were not in the training set (due to partitioning). if attempt_count == self._error_retry_limit - 1: - logger.info(f"WARNING: Disabling seeds for partition {partition.idx} due to previous failed generation attempts...") + logger.info( + f"WARNING: Disabling seeds for partition {partition.idx} due to previous failed generation attempts..." + ) else: - logger.info("Partition has seed fields, uploading seed artifact...") + logger.info( + "Partition has seed fields, uploading seed artifact..." + ) use_seeds = True removed_artifact = self._remove_unused_artifact() if removed_artifact is None: @@ -547,7 +582,8 @@ def run_next_partition(self, gen_payload: GenPayload) -> Optional[str]: return None filename = f"{self.strategy_id}-seeds-{partition.idx}.csv" - artifact = self._df_to_artifact(gen_payload.seed_df, filename) + artifact = self._df_to_artifact( + gen_payload.seed_df, filename) new_payload = GenPayload( num_records=gen_payload.num_records, @@ -628,7 +664,8 @@ def _get_synthetic_data(self, job_type: str, artifact_type: str) -> pd.DataFrame num_completed = self._status_counter.get(Status.COMPLETED, 0) elif job_type == "run": self._update_handler_status() - num_completed = self._handler_status_counter.get(Status.COMPLETED, 0) + num_completed = self._handler_status_counter.get( + Status.COMPLETED, 0) else: raise ValueError("invalid job_type") @@ -697,10 +734,12 @@ def generate_data( seed_df: Optional[pd.DataFrame] = None, num_records: Optional[int] = None, max_invalid: Optional[int] = None, - clear_cache: bool = False + clear_cache: bool = False, ): + if seed_df is None and not num_records: - raise ValueError("must provide a seed_df or num_records to generate") + raise ValueError( + "must provide a seed_df or num_records to generate") if isinstance(seed_df, pd.DataFrame) and num_records: raise ValueError("must use one of seed_df or num_records only") @@ -769,7 +808,7 @@ def generate_data( if handler_id is not None: # Go around and try again if this succeeded continue - + # Catch a 4xx in the event we are at capacity, or something else goes wrong except Exception as err: logger.warning(f"Error running model: {str(err)}") diff --git a/src/trainer/strategy.py b/src/gretel_trainer/strategy.py similarity index 90% rename from src/trainer/strategy.py rename to src/gretel_trainer/strategy.py index 
72ba50b8..cf7fc1e4 100644 --- a/src/trainer/strategy.py +++ b/src/gretel_trainer/strategy.py @@ -32,7 +32,7 @@ def extract_df(self, df: pd.DataFrame) -> pd.DataFrame: if self.columns is not None: df = df[self.columns.headers] - return df.iloc[self.rows.start : self.rows.end] # noqa + return df.iloc[self.rows.start: self.rows.end] # noqa def update_ctx(self, update: dict): self.ctx.update(update) @@ -47,10 +47,12 @@ class PartitionConstraints: def __post_init__(self): if self.max_row_count is not None and self.max_row_partitions is not None: - raise AttributeError("cannot use both max_row_count and max_row_partitions") + raise AttributeError( + "cannot use both max_row_count and max_row_partitions") if self.max_row_count is None and self.max_row_partitions is None: - raise AttributeError("must use one of max_row_count or max_row_partitions") + raise AttributeError( + "must use one of max_row_count or max_row_partitions") @property def header_cluster_count(self) -> int: @@ -82,7 +84,9 @@ def _build_partitions( partitions.append( Partition( rows=RowPartition(start=next_start, end=next_end), - columns=ColumnPartition(headers=header_cluster, idx=idx, seed_headers=seed_headers), + columns=ColumnPartition( + headers=header_cluster, idx=idx, seed_headers=seed_headers + ), idx=partition_idx, ) ) @@ -114,7 +118,9 @@ def _build_partitions( rows=RowPartition( start=curr_start, end=curr_start + chunk_size ), - columns=ColumnPartition(headers=header_cluster, idx=idx, seed_headers=seed_headers), + columns=ColumnPartition( + headers=header_cluster, idx=idx, seed_headers=seed_headers + ), idx=partition_idx, ) ) @@ -141,7 +147,7 @@ def from_dataframe( id=id, partitions=partitions, header_cluster_count=constraints.header_cluster_count, - original_headers=list(df) + original_headers=list(df), ) @classmethod diff --git a/src/gretel_trainer/trainer.py b/src/gretel_trainer/trainer.py new file mode 100644 index 00000000..795b35ac --- /dev/null +++ b/src/gretel_trainer/trainer.py @@ -0,0 +1,210 @@ +"""Main Trainer Module""" + +import json +import logging +import os.path +from collections import namedtuple +from enum import Enum + +import pandas as pd +from gretel_client import ClientConfig, configure_session +from gretel_client.projects import create_or_get_unique_project +from gretel_client.projects.jobs import Status +from gretel_client.projects.models import read_model_config +from gretel_synthetics.utils.header_clusters import cluster + +from gretel_trainer import runner, strategy + +logger = logging.getLogger(__name__) +logger.setLevel("DEBUG") + +DEFAULT_PROJECT = "trainer" +DEFAULT_CACHE = f"{DEFAULT_PROJECT}-runner.json" + + +class ExtendedEnum(Enum): + """Utility class for Model enum""" + + @classmethod + def get_types(cls): + return list(map(lambda c: c.name, cls)) + + @classmethod + def get_config(cls, model: str): + return cls[model].config + + +class Model(namedtuple("Model", "config"), ExtendedEnum): + """Enum to pair valid models and configurations""" + + GretelLSTM = "synthetics/default" + GretelCTGAN = "synthetics/ctgan" + + +class Trainer: + """Automated model training and synthetic data generation tool + + Args: + project_name (str, optional): Gretel project name. Defaults to "trainer". + max_header_clusters (int, optional): Max number of clusters per batch. Defaults to 20. + max_rows (int, optional): Max number of rows per batch. Defaults to 50000. + model_type (str, optional): Options include ["GretelLSTM", "GretelCTGAN"]. Defaults to "GretelLSTM". 
+ model_params (dict, optional): Modify model configuration settings by key. E.g. {'epochs': 20} + cache_file (str, optional): Select a path to save or load the cache file. Default is `[project_name]-runner.json`. + overwrite (bool, optional): Overwrite previous progress. Defaults to True. + """ + + def __init__( + self, + project_name: str = "trainer", + max_header_clusters: int = 20, + max_rows: int = 50000, + model_type: str = "GretelLSTM", + model_params: dict = {}, + cache_file: str = None, + overwrite: bool = True, + ): + + configure_session(api_key="prompt", cache="yes", validate=True) + + self.df = None + self.dataset_path = None + self.run = None + self.project_name = project_name + self.project = create_or_get_unique_project(name=project_name) + self.max_header_clusters = max_header_clusters + self.max_rows = max_rows + self.overwrite = overwrite + self.cache_file = self._get_cache_file(cache_file) + + if model_type in Model.get_types(): + self.config = read_model_config(Model.get_config(model_type)) + + # Update default config settings with params by key + for key, value in model_params.items(): + self.config = self._replace_nested_key(self.config, key, value) + else: + raise ValueError( + f"Invalid model type. Must be one of {Model.get_types()}") + + if self.overwrite: + logger.debug(json.dumps(self.config, indent=2)) + + @classmethod + def load( + cls, cache_file: str = DEFAULT_CACHE, project_name: str = DEFAULT_PROJECT + ) -> runner.StrategyRunner: + """Load an existing project from a cache. + + Args: + cache_file (str, optional): Valid file path to load the cache file from. Defaults to `[project-name]-runner.json` + + Returns: + Trainer: an initialized Trainer with the loaded StrategyRunner attached as `run`. + """ + project = create_or_get_unique_project(name=project_name) + model = cls(cache_file=cache_file, + project_name=project_name, overwrite=False) + + if not os.path.exists(cache_file): + raise ValueError( + f"Unable to find `{cache_file}`. Please specify a valid cache_file." + ) + + model.run = model._initialize_run(df=None, overwrite=model.overwrite) + return model + + def train(self, dataset_path: str, round_decimals: int = 4): + """Train a model on the dataset + + Args: + dataset_path (str): Path or URL to CSV + round_decimals (int, optional): Round decimals in CSV as preprocessing step. Defaults to `4`. + """ + self.dataset_path = dataset_path + self.df = self._preprocess_data( + dataset_path=dataset_path, round_decimals=round_decimals + ) + self.run = self._initialize_run(df=self.df, overwrite=self.overwrite) + self.run.train_all_partitions() + + def generate(self, num_records: int = 500) -> pd.DataFrame: + """Generate synthetic data + + Args: + num_records (int, optional): Number of records to generate from model. Defaults to 500. + + Returns: + pd.DataFrame: Synthetic data. 
+ """ + self.run.generate_data( + num_records=num_records, max_invalid=None, clear_cache=True + ) + return self.run.get_synthetic_data() + + def _replace_nested_key(self, data, key, value): + """Replace nested keys""" + if isinstance(data, dict): + return { + k: value if k == key else self._replace_nested_key( + v, key, value) + for k, v in data.items() + } + elif isinstance(data, list): + return [self._replace_nested_key(v, key, value) for v in data] + else: + return data + + def _preprocess_data( + self, dataset_path: str, round_decimals: int = 4 + ) -> pd.DataFrame: + """Preprocess input data""" + tmp = pd.read_csv(dataset_path, low_memory=False) + tmp = tmp.round(round_decimals) + return tmp + + def _get_cache_file(self, cache_file: str) -> str: + """Select a path to store the runtime cache to initialize a model""" + if cache_file is None: + cache_file = f"{self.project_name}-runner.json" + + if os.path.exists(cache_file): + if self.overwrite: + logger.warning( + f"Overwriting existing run cache: {cache_file}.") + else: + logger.info(f"Using existing run cache: {cache_file}.") + else: + logger.info(f"Creating new run cache: {cache_file}.") + return cache_file + + def _initialize_run( + self, df: pd.DataFrame = None, overwrite: bool = True + ) -> runner.StrategyRunner: + """Create training jobs""" + constraints = None + if df is None: + df = pd.DataFrame() + + if not df.empty: + header_clusters = cluster( + df, maxsize=self.max_header_clusters, plot=False) + logger.info( + f"Header clustering created {len(header_clusters)} cluster(s) " + f"of length(s) {[len(x) for x in header_clusters]}" + ) + + constraints = strategy.PartitionConstraints( + header_clusters=header_clusters, max_row_count=self.max_rows + ) + + run = runner.StrategyRunner( + strategy_id=f"{self.project_name}", + df=self.df, + cache_file=self.cache_file, + cache_overwrite=overwrite, + model_config=self.config, + partition_constraints=constraints, + project=self.project, + ) + return run diff --git a/tests/test_strategy.py b/tests/test_strategy.py index 8440f816..74d2e990 100644 --- a/tests/test_strategy.py +++ b/tests/test_strategy.py @@ -6,7 +6,7 @@ import pytest from gretel_synthetics.utils.header_clusters import cluster -from trainer.strategy import PartitionConstraints, PartitionStrategy +from gretel_trainer.strategy import PartitionConstraints, PartitionStrategy @pytest.fixture(scope="module", autouse=True)