From f5baf542c65c1dc3efbabd00dc3567e1fbb1ccee Mon Sep 17 00:00:00 2001
From: Mike Knepper
Date: Tue, 1 Aug 2023 08:17:33 -0500
Subject: [PATCH] Remove column partitioning (#144)

---
 notebooks/custom-example.py      |  5 +--
 notebooks/trainer-examples.ipynb |  1 -
 requirements.txt                 |  3 +-
 src/gretel_trainer/models.py     | 38 ++++++++--------
 src/gretel_trainer/runner.py     | 14 +-----
 src/gretel_trainer/strategy.py   | 41 ++++++------------
 src/gretel_trainer/trainer.py    | 13 ------
 tests/test_strategy.py           | 74 +++-----------------------
 8 files changed, 41 insertions(+), 148 deletions(-)

diff --git a/notebooks/custom-example.py b/notebooks/custom-example.py
index b4ae3789..cdf1ea7a 100644
--- a/notebooks/custom-example.py
+++ b/notebooks/custom-example.py
@@ -1,8 +1,6 @@
 from gretel_client import configure_session
-
 from gretel_trainer import Trainer
-from gretel_trainer.models import GretelLSTM, GretelACTGAN
-
+from gretel_trainer.models import GretelACTGAN, GretelLSTM
 
 # Configure Gretel credentials
 configure_session(api_key="prompt", cache="yes", validate=True)
@@ -13,7 +11,6 @@
 
 # configs can be either a string, dict, or path
 model_type = GretelACTGAN(
     config="synthetics/tabular-actgan",
-    max_header_clusters=100,
     max_rows=50000
 )
diff --git a/notebooks/trainer-examples.ipynb b/notebooks/trainer-examples.ipynb
index 430ecd2b..8642fb2d 100644
--- a/notebooks/trainer-examples.ipynb
+++ b/notebooks/trainer-examples.ipynb
@@ -54,7 +54,6 @@
     "\n",
     "model_type = GretelACTGAN(\n",
     "    config=\"synthetics/tabular-actgan\", \n",
-    "    max_header_clusters=100, \n",
     "    max_rows=50000\n",
     ")\n",
     "\n",
diff --git a/requirements.txt b/requirements.txt
index 9dff7157..2062012f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,6 @@
 boto3~=1.20
-dask[complete]==2023.5.1
+dask[dataframe]==2023.5.1
 gretel-client>=0.16.0
-gretel-synthetics[utils]
 jinja2~=3.1
 networkx~=3.0
 numpy~=1.20
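Taken together, the notebook edits above reduce the public usage to the following; a minimal sketch, assuming a configured Gretel session (the dataset path is a placeholder):

```python
from gretel_client import configure_session
from gretel_trainer import Trainer
from gretel_trainer.models import GretelACTGAN

configure_session(api_key="prompt", cache="yes", validate=True)

# No more max_header_clusters: the row cap is the only partitioning knob left.
model_type = GretelACTGAN(
    config="synthetics/tabular-actgan",
    max_rows=50000,
)

model = Trainer(model_type=model_type)
model.train("my-dataset.csv")  # placeholder dataset path
```
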
diff --git a/src/gretel_trainer/models.py b/src/gretel_trainer/models.py
index d4b1d244..b2d5356a 100644
--- a/src/gretel_trainer/models.py
+++ b/src/gretel_trainer/models.py
@@ -49,24 +49,20 @@ class _BaseConfig:
     """
 
     """This should be overridden on concrete classes"""
-    _max_header_clusters_limit: int
     _max_rows_limit: int
     _model_slug: str
 
     # Should be set by concrete constructors
     config: dict
     max_rows: int
-    max_header_clusters: int
 
     def __init__(
         self,
         config: Union[str, dict],
         max_rows: int,
-        max_header_clusters: int,
     ):
         self.config = read_model_config(config)
         self.max_rows = max_rows
-        self.max_header_clusters = max_header_clusters
 
         self.validate()
@@ -89,11 +85,6 @@ def validate(self):
                 f"max_rows must be less than {self._max_rows_limit} for this model type."
             )
 
-        if self.max_header_clusters > self._max_header_clusters_limit:
-            raise ValueError(
-                f"max_header_clusters must be less than {self._max_header_clusters_limit} for this model type."
-            )
-
     def _replace_nested_key(self, data, key, value) -> dict:
         """Replace nested keys"""
         if isinstance(data, dict):
@@ -114,10 +105,9 @@ class GretelLSTM(_BaseConfig):
 
     Args:
         config (str/dict, optional): Either a string representing the path to the config on the local filesystem, a string representing a path to the default Gretel configurations, or a dictionary containing the configurations. Default: "synthetics/tabular-lstm", a default Gretel configuration
         max_rows (int, optional): The number of rows of synthetic data to generate. Defaults to 50000
-        max_header_clusters (int, optional): Default: 20
+        max_header_clusters (int, optional): This parameter is deprecated and will be removed in a future release.
     """
 
-    _max_header_clusters_limit: int = 30
     _max_rows_limit: int = 5_000_000
     _model_slug: str = "synthetics"
@@ -125,12 +115,12 @@ def __init__(
         self,
         config="synthetics/tabular-lstm",
         max_rows=50_000,
-        max_header_clusters=20,
+        max_header_clusters=None,
     ):
+        _max_header_clusters_deprecation_warning(max_header_clusters)
         super().__init__(
             config=config,
             max_rows=max_rows,
-            max_header_clusters=max_header_clusters,
         )
 
 
@@ -143,10 +133,9 @@ class GretelACTGAN(_BaseConfig):
 
     Args:
         config (str/dict, optional): Either a string representing the path to the config on the local filesystem, a string representing a path to the default Gretel configurations, or a dictionary containing the configurations. Default: "synthetics/tabular-actgan", a default Gretel configuration
         max_rows (int, optional): The number of rows of synthetic data to generate. Defaults to 1000000
-        max_header_clusters (int, optional): Default: 500
+        max_header_clusters (int, optional): This parameter is deprecated and will be removed in a future release.
     """
 
-    _max_header_clusters_limit: int = 5_000
     _max_rows_limit: int = 5_000_000
     _model_slug: str = "actgan"
@@ -154,12 +143,12 @@ def __init__(
         self,
         config="synthetics/tabular-actgan",
         max_rows=1_000_000,
-        max_header_clusters=1_000,
+        max_header_clusters=None,
     ):
+        _max_header_clusters_deprecation_warning(max_header_clusters)
         super().__init__(
             config=config,
             max_rows=max_rows,
-            max_header_clusters=max_header_clusters,
         )
 
 
@@ -172,10 +161,9 @@ class GretelAmplify(_BaseConfig):
 
     Args:
         config (str/dict, optional): Either a string representing the path to the config on the local filesystem, a string representing a path to the default Gretel configurations, or a dictionary containing the configurations. Default: "synthetics/amplify", a default Gretel configuration for Amplify.
         max_rows (int, optional): The number of rows of synthetic data to generate. Defaults to 50000
-        max_header_clusters (int, optional): Default: 50
+        max_header_clusters (int, optional): This parameter is deprecated and will be removed in a future release.
     """
 
-    _max_header_clusters_limit: int = 1_000
     _max_rows_limit: int = 1_000_000_000
     _model_slug: str = "amplify"
@@ -183,10 +171,18 @@ def __init__(
         self,
         config="synthetics/amplify",
         max_rows=50_000,
-        max_header_clusters=500,
+        max_header_clusters=None,
     ):
+        _max_header_clusters_deprecation_warning(max_header_clusters)
         super().__init__(
             config=config,
             max_rows=max_rows,
-            max_header_clusters=max_header_clusters,
         )
+
+
+def _max_header_clusters_deprecation_warning(value: Optional[int]) -> None:
+    if value is not None:
+        logger.warning(
+            "Trainer no longer performs header clustering. "
+            "The max_header_clusters parameter is deprecated and will be removed in a future release."
+        )
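The shim above keeps old call sites working while making the argument inert. A minimal sketch of the caller-visible behavior (assuming logging is configured and the default model configs are reachable):

```python
import logging

logging.basicConfig(level=logging.WARNING)

from gretel_trainer.models import GretelLSTM

# Old code still constructs the model, but the value is ignored and a
# warning is logged:
# "Trainer no longer performs header clustering. The max_header_clusters
#  parameter is deprecated and will be removed in a future release."
model = GretelLSTM(max_header_clusters=20)

# Omitting the argument (the new default of None) stays silent.
model = GretelLSTM()
```
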
diff --git a/src/gretel_trainer/runner.py b/src/gretel_trainer/runner.py
index 18166b83..fdd539f0 100644
--- a/src/gretel_trainer/runner.py
+++ b/src/gretel_trainer/runner.py
@@ -77,7 +77,6 @@ class GenPayload:
 @dataclass
 class RemoteDFPayload:
     partition: int
-    slot: int
     job_type: str
     uid: Optional[str]
     handler_uid: Optional[str]
@@ -634,11 +633,7 @@ def _get_synthetic_data(self, job_type: str, artifact_type: str) -> pd.DataFrame
                 "Not all partitions are completed, cannot fetch synthetic data from trained models"
             )
 
-        # We will have at least one column-wise DF, this holds
-        # one DF for each header cluster we have
-        df_chunks = {
-            i: pd.DataFrame() for i in range(0, self._strategy.header_cluster_count)
-        }
+        df = pd.DataFrame()
 
         pool = ThreadPoolExecutor()
         futures = []
@@ -648,7 +643,6 @@
             # ones they need to use.
             payload = RemoteDFPayload(
                 partition=partition.idx,
-                slot=partition.columns.idx,
                 job_type=job_type,
                 handler_uid=partition.ctx.get(HANDLER, {}).get(HANDLER_ID),
                 uid=partition.ctx.get(MODEL_ID),
@@ -662,12 +656,8 @@
         for future in futures:
             payload, this_df = future.result()
 
-            curr_df = df_chunks[payload.slot]
-            df_chunks[payload.slot] = pd.concat([curr_df, this_df]).reset_index(
-                drop=True
-            )
+            df = pd.concat([df, this_df]).reset_index(drop=True)
 
-        df = pd.concat(list(df_chunks.values()), axis=1)
         return df
 
     def _maybe_restore_df_headers(self, df) -> pd.DataFrame:
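The simplification in `_get_synthetic_data` is easiest to see in isolation: with column slots gone, every partition's DataFrame already carries the full set of columns, so reassembly is a single row-wise concatenation rather than a concat per slot followed by a column-wise join. A toy sketch with made-up partition results:

```python
import pandas as pd

# Each row partition now yields all columns (no per-cluster slices).
partition_results = [
    pd.DataFrame({"a": [1, 2], "b": [5, 6]}),
    pd.DataFrame({"a": [3, 4], "b": [7, 8]}),
]

# Mirrors the new accumulation loop above.
df = pd.DataFrame()
for this_df in partition_results:
    df = pd.concat([df, this_df]).reset_index(drop=True)

assert df.shape == (4, 2)  # 4 rows, both columns, no axis=1 join needed
```
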
diff --git a/src/gretel_trainer/strategy.py b/src/gretel_trainer/strategy.py
index e7454766..4d90b401 100644
--- a/src/gretel_trainer/strategy.py
+++ b/src/gretel_trainer/strategy.py
@@ -20,7 +20,6 @@ class RowPartition(BaseModel):
 class ColumnPartition(BaseModel):
     headers: Optional[List[str]]
     seed_headers: Optional[List[str]]
-    idx: int
 
 
 class Partition(BaseModel):
@@ -42,25 +41,14 @@ def update_ctx(self, update: dict):
 @dataclass
 class PartitionConstraints:
     max_row_count: int
-    header_clusters: Optional[List[List[str]]] = None
     seed_headers: Optional[List[str]] = None
 
-    @property
-    def header_cluster_count(self) -> int:
-        if self.header_clusters is None:
-            return 1
-        return len(self.header_clusters)
-
 
 def _build_partitions(
     df: pd.DataFrame, constraints: PartitionConstraints
 ) -> List[Partition]:
     total_rows = len(df)
 
-    header_clusters = constraints.header_clusters
-    if header_clusters is None:
-        header_clusters = [list(df.columns)]
-
     partitions = []
     partition_idx = 0
 
     partition_count = math.ceil(total_rows / constraints.max_row_count)
@@ -77,20 +65,19 @@ def _build_partitions(
     curr_start = 0
     for chunk_size in chunks:
-        for idx, header_cluster in enumerate(header_clusters):
-            seed_headers = constraints.seed_headers if idx == 0 else None
-            partitions.append(
-                Partition(
-                    rows=RowPartition(
-                        start=curr_start, end=curr_start + chunk_size
-                    ),
-                    columns=ColumnPartition(
-                        headers=header_cluster, idx=idx, seed_headers=seed_headers
-                    ),
-                    idx=partition_idx,
-                )
+        seed_headers = constraints.seed_headers
+        partitions.append(
+            Partition(
+                rows=RowPartition(
+                    start=curr_start, end=curr_start + chunk_size
+                ),
+                columns=ColumnPartition(
+                    headers=list(df.columns), seed_headers=seed_headers
+                ),
+                idx=partition_idx,
             )
-            partition_idx += 1
+        )
+        partition_idx += 1
 
         curr_start += chunk_size
 
     return partitions
@@ -99,7 +86,6 @@
 class PartitionStrategy(BaseModel):
     id: str
     partitions: List[Partition]
-    header_cluster_count: int
     original_headers: Optional[List[str]]
     status_counter: Optional[dict]
     _disk_location: Path = PrivateAttr(default=None)
@@ -112,7 +98,6 @@ def from_dataframe(
         return cls(
             id=id,
             partitions=partitions,
-            header_cluster_count=constraints.header_cluster_count,
            original_headers=list(df.columns),
             status_counter=None,
         )
@@ -135,7 +120,7 @@ def partition_count(self) -> int:
 
     @property
     def row_partition_count(self) -> int:
-        return math.ceil(len(self.partitions) / self.header_cluster_count)
+        return len(self.partitions)
 
     def save_to(self, dest: Union[Path, str], overwrite: bool = False):
         location = Path(dest)
diff --git a/src/gretel_trainer/trainer.py b/src/gretel_trainer/trainer.py
index f7908a32..b8355078 100644
--- a/src/gretel_trainer/trainer.py
+++ b/src/gretel_trainer/trainer.py
@@ -11,7 +11,6 @@
 import pandas as pd
 from gretel_client.config import get_session_config, RunnerMode
 from gretel_client.projects import create_or_get_unique_project
-from gretel_synthetics.utils.header_clusters import cluster
 
 from gretel_trainer import runner, strategy
 from gretel_trainer.models import _BaseConfig, determine_best_model
@@ -206,19 +205,7 @@ def _initialize_run(
 
         model_config = self.model_type.config
 
-        header_clusters = cluster(
-            df,
-            maxsize=self.model_type.max_header_clusters,
-            header_prefix=seed_fields,
-            plot=False,
-        )
-        logger.info(
-            f"Header clustering created {len(header_clusters)} cluster(s) "
-            f"of length(s) {[len(x) for x in header_clusters]}"
-        )
-
         constraints = strategy.PartitionConstraints(
-            header_clusters=header_clusters,
             max_row_count=self.model_type.max_rows,
             seed_headers=seed_fields,
         )
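With header clustering gone, a partitioning strategy is fully described by a row cap plus optional seed columns. A minimal sketch of the simplified API (the DataFrame and seed header are illustrative):

```python
import pandas as pd

from gretel_trainer.strategy import PartitionConstraints, PartitionStrategy

df = pd.DataFrame({"goal": range(250), "value": range(250)})

constraints = PartitionConstraints(
    max_row_count=100,
    seed_headers=["goal"],  # now applied to every partition, not just slot 0
)
strategy = PartitionStrategy.from_dataframe("example", df, constraints)

# ceil(250 / 100) == 3 row partitions, each spanning all columns.
assert strategy.partition_count == 3
assert strategy.row_partition_count == 3
```
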
diff --git a/tests/test_strategy.py b/tests/test_strategy.py
index 8c02a097..23673ce3 100644
--- a/tests/test_strategy.py
+++ b/tests/test_strategy.py
@@ -1,11 +1,9 @@
 from pathlib import Path
 from typing import List
-from dataclasses import dataclass
 
 import pandas as pd
 import pytest
-from gretel_synthetics.utils.header_clusters import cluster
 
 from gretel_trainer.strategy import PartitionConstraints, PartitionStrategy
@@ -15,23 +13,8 @@ def test_df() -> pd.DataFrame:
 
 
 @pytest.fixture(scope="module")
-def header_clusters(test_df) -> List[List[str]]:
-    clusters = cluster(test_df)
-    assert len(clusters) == 2
-    return clusters
-
-
-@dataclass
-class ClusterData:
-    clusters: List[List[str]]
-    seeds: List[str]
-
-
-@pytest.fixture(scope="module")
-def header_clusters_seed(test_df) -> ClusterData:
-    seeds = ["goal", "goal_type", "goals"]
-    clusters = cluster(test_df, header_prefix=seeds)
-    return ClusterData(clusters=clusters, seeds=seeds)
+def test_seeds() -> List[str]:
+    return ["goal", "goal_type", "goals"]
 
 
 @pytest.mark.parametrize(
     "constraints",
     [
         PartitionConstraints(max_row_count=1000),
         PartitionConstraints(max_row_count=100),
     ],
 )
@@ -64,60 +47,17 @@ def test_strategy_all_columns(constraints: PartitionConstraints, test_df):
     assert compare.shape == test_df.shape
 
 
-@pytest.mark.parametrize(
-    "constraints",
-    [
-        PartitionConstraints(max_row_count=1000),
-        PartitionConstraints(max_row_count=100),
-    ],
-)
-def test_strategy_column_batches(
-    constraints: PartitionConstraints, test_df, header_clusters
-):
-    constraints.header_clusters = header_clusters
-
-    strategy = PartitionStrategy.from_dataframe("foo", test_df, constraints)
-    assert (
-        len(test_df) // constraints.max_row_count
-        <= strategy.partition_count / len(header_clusters)
-        <= len(test_df) // constraints.max_row_count + 1
-    )
-
-    # partitions are of roughly equal size
-    extracted_df_lengths = [len(partition.extract_df(test_df)) for partition in strategy.partitions]
-    assert max(extracted_df_lengths) - min(extracted_df_lengths) <= 1
-
-    part1 = pd.DataFrame()
-    part2 = pd.DataFrame()
-    for idx, partition in enumerate(strategy.partitions):
-        assert partition.idx == idx
-        tmp_df = partition.extract_df(test_df)
-        if list(tmp_df.columns) == header_clusters[0]:
-            part1 = pd.concat([part1, tmp_df]).reset_index(drop=True)
-        else:
-            part2 = pd.concat([part2, tmp_df]).reset_index(drop=True)
-
-    final = pd.concat([part1, part2], axis=1)
-    assert final.shape == test_df.shape
-
-
-def test_strategy_seeds(test_df, header_clusters_seed: ClusterData):
+def test_strategy_seeds(test_df, test_seeds):
     constraints = PartitionConstraints(max_row_count=100)
-    constraints.header_clusters = header_clusters_seed.clusters
-    constraints.seed_headers = header_clusters_seed.seeds
+    constraints.seed_headers = test_seeds
 
     strategy = PartitionStrategy.from_dataframe("foo", test_df, constraints)
 
     for partition in strategy.partitions:
-        if partition.columns.idx == 0:
-            assert partition.columns.seed_headers == header_clusters_seed.seeds
-        else:
-            assert not partition.columns.seed_headers
+        assert partition.columns.seed_headers == test_seeds
 
 
-def test_read_write(test_df, header_clusters, tmpdir):
+def test_read_write(test_df, tmpdir):
     save_location = Path(tmpdir) / "data.json"
-    constraints = PartitionConstraints(
-        max_row_count=100, header_clusters=header_clusters
-    )
+    constraints = PartitionConstraints(max_row_count=100)
     strategy = PartitionStrategy.from_dataframe("foo", test_df, constraints)
 
     # Improper filename