Skip to content

Commit

Permalink
Remove column partitioning (#144)
Browse files Browse the repository at this point in the history
  • Loading branch information
mikeknep authored Aug 1, 2023
1 parent 171fa65 commit f5baf54
Show file tree
Hide file tree
Showing 8 changed files with 41 additions and 148 deletions.
5 changes: 1 addition & 4 deletions notebooks/custom-example.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from gretel_client import configure_session

from gretel_trainer import Trainer
from gretel_trainer.models import GretelLSTM, GretelACTGAN

from gretel_trainer.models import GretelACTGAN, GretelLSTM

# Configure Gretel credentials
configure_session(api_key="prompt", cache="yes", validate=True)
Expand All @@ -13,7 +11,6 @@
# configs can be either a string, dict, or path
model_type = GretelACTGAN(
config="synthetics/tabular-actgan",
max_header_clusters=100,
max_rows=50000
)

Expand Down
1 change: 0 additions & 1 deletion notebooks/trainer-examples.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
"\n",
"model_type = GretelACTGAN(\n",
" config=\"synthetics/tabular-actgan\", \n",
" max_header_clusters=100, \n",
" max_rows=50000\n",
")\n",
"\n",
Expand Down
3 changes: 1 addition & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
boto3~=1.20
dask[complete]==2023.5.1
dask[dataframe]==2023.5.1
gretel-client>=0.16.0
gretel-synthetics[utils]
jinja2~=3.1
networkx~=3.0
numpy~=1.20
Expand Down
38 changes: 17 additions & 21 deletions src/gretel_trainer/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,24 +49,20 @@ class _BaseConfig:
"""

"""This should be overridden on concrete classes"""
_max_header_clusters_limit: int
_max_rows_limit: int
_model_slug: str

# Should be set by concrete constructors
config: dict
max_rows: int
max_header_clusters: int

def __init__(
self,
config: Union[str, dict],
max_rows: int,
max_header_clusters: int,
):
self.config = read_model_config(config)
self.max_rows = max_rows
self.max_header_clusters = max_header_clusters

self.validate()

Expand All @@ -89,11 +85,6 @@ def validate(self):
f"max_rows must be less than {self._max_rows_limit} for this model type."
)

if self.max_header_clusters > self._max_header_clusters_limit:
raise ValueError(
f"max_header_clusters must be less than {self._max_header_clusters_limit} for this model type."
)

def _replace_nested_key(self, data, key, value) -> dict:
"""Replace nested keys"""
if isinstance(data, dict):
Expand All @@ -114,23 +105,22 @@ class GretelLSTM(_BaseConfig):
Args:
config (str/dict, optional): Either a string representing the path to the config on the local filesystem, a string representing a path to the default Gretel configurations, or a dictionary containing the configurations. Default: "synthetics/tabular-lstm", a default Gretel configuration
max_rows (int, optional): The number of rows of synthetic data to generate. Defaults to 50000
max_header_clusters (int, optional): Default: 20
max_header_clusters (int, optional): This parameter is deprecated and will be removed in a future release.
"""

_max_header_clusters_limit: int = 30
_max_rows_limit: int = 5_000_000
_model_slug: str = "synthetics"

def __init__(
self,
config="synthetics/tabular-lstm",
max_rows=50_000,
max_header_clusters=20,
max_header_clusters=None,
):
_max_header_clusters_deprecation_warning(max_header_clusters)
super().__init__(
config=config,
max_rows=max_rows,
max_header_clusters=max_header_clusters,
)


Expand All @@ -143,23 +133,22 @@ class GretelACTGAN(_BaseConfig):
Args:
config (str/dict, optional): Either a string representing the path to the config on the local filesystem, a string representing a path to the default Gretel configurations, or a dictionary containing the configurations. Default: "synthetics/tabular-actgan", a default Gretel configuration
max_rows (int, optional): The number of rows of synthetic data to generate. Defaults to 50000
max_header_clusters (int, optional): Default: 500
max_header_clusters (int, optional): This parameter is deprecated and will be removed in a future release.
"""

_max_header_clusters_limit: int = 5_000
_max_rows_limit: int = 5_000_000
_model_slug: str = "actgan"

def __init__(
self,
config="synthetics/tabular-actgan",
max_rows=1_000_000,
max_header_clusters=1_000,
max_header_clusters=None,
):
_max_header_clusters_deprecation_warning(max_header_clusters)
super().__init__(
config=config,
max_rows=max_rows,
max_header_clusters=max_header_clusters,
)


Expand All @@ -172,21 +161,28 @@ class GretelAmplify(_BaseConfig):
Args:
config (str/dict, optional): Either a string representing the path to the config on the local filesystem, a string representing a path to the default Gretel configurations, or a dictionary containing the configurations. Default: "synthetics/amplify", a default Gretel configuration for Amplify.
max_rows (int, optional): The number of rows of synthetic data to generate. Defaults to 50000
max_header_clusters (int, optional): Default: 50
max_header_clusters (int, optional): This parameter is deprecated and will be removed in a future release.
"""

_max_header_clusters_limit: int = 1_000
_max_rows_limit: int = 1_000_000_000
_model_slug: str = "amplify"

def __init__(
self,
config="synthetics/amplify",
max_rows=50_000,
max_header_clusters=500,
max_header_clusters=None,
):
_max_header_clusters_deprecation_warning(max_header_clusters)
super().__init__(
config=config,
max_rows=max_rows,
max_header_clusters=max_header_clusters,
)


def _max_header_clusters_deprecation_warning(value: Optional[int]) -> None:
    """Log a deprecation warning if a max_header_clusters value was supplied.

    Args:
        value: The max_header_clusters argument passed by the caller, or
            None when the caller did not provide one (no warning emitted).
    """
    if value is None:
        return
    logger.warning(
        "Trainer no longer performs header clustering. "
        "The max_header_clusters parameter is deprecated and will be removed in a future release."
    )
14 changes: 2 additions & 12 deletions src/gretel_trainer/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ class GenPayload:
@dataclass
class RemoteDFPayload:
partition: int
slot: int
job_type: str
uid: Optional[str]
handler_uid: Optional[str]
Expand Down Expand Up @@ -634,11 +633,7 @@ def _get_synthetic_data(self, job_type: str, artifact_type: str) -> pd.DataFrame
"Not all partitions are completed, cannot fetch synthetic data from trained models"
)

# We will have at least one column-wise DF, this holds
# one DF for each header cluster we have
df_chunks = {
i: pd.DataFrame() for i in range(0, self._strategy.header_cluster_count)
}
df = pd.DataFrame()

pool = ThreadPoolExecutor()
futures = []
Expand All @@ -648,7 +643,6 @@ def _get_synthetic_data(self, job_type: str, artifact_type: str) -> pd.DataFrame
# ones they need to use.
payload = RemoteDFPayload(
partition=partition.idx,
slot=partition.columns.idx,
job_type=job_type,
handler_uid=partition.ctx.get(HANDLER, {}).get(HANDLER_ID),
uid=partition.ctx.get(MODEL_ID),
Expand All @@ -662,12 +656,8 @@ def _get_synthetic_data(self, job_type: str, artifact_type: str) -> pd.DataFrame
for future in futures:
payload, this_df = future.result()

curr_df = df_chunks[payload.slot]
df_chunks[payload.slot] = pd.concat([curr_df, this_df]).reset_index(
drop=True
)
df = pd.concat([df, this_df]).reset_index(drop=True)

df = pd.concat(list(df_chunks.values()), axis=1)
return df

def _maybe_restore_df_headers(self, df) -> pd.DataFrame:
Expand Down
41 changes: 13 additions & 28 deletions src/gretel_trainer/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ class RowPartition(BaseModel):
class ColumnPartition(BaseModel):
headers: Optional[List[str]]
seed_headers: Optional[List[str]]
idx: int


class Partition(BaseModel):
Expand All @@ -42,25 +41,14 @@ def update_ctx(self, update: dict):
@dataclass
class PartitionConstraints:
max_row_count: int
header_clusters: Optional[List[List[str]]] = None
seed_headers: Optional[List[str]] = None

@property
def header_cluster_count(self) -> int:
if self.header_clusters is None:
return 1
return len(self.header_clusters)


def _build_partitions(
df: pd.DataFrame, constraints: PartitionConstraints
) -> List[Partition]:
total_rows = len(df)

header_clusters = constraints.header_clusters
if header_clusters is None:
header_clusters = [list(df.columns)]

partitions = []
partition_idx = 0
partition_count = math.ceil(total_rows / constraints.max_row_count)
Expand All @@ -77,20 +65,19 @@ def _build_partitions(

curr_start = 0
for chunk_size in chunks:
for idx, header_cluster in enumerate(header_clusters):
seed_headers = constraints.seed_headers if idx == 0 else None
partitions.append(
Partition(
rows=RowPartition(
start=curr_start, end=curr_start + chunk_size
),
columns=ColumnPartition(
headers=header_cluster, idx=idx, seed_headers=seed_headers
),
idx=partition_idx,
)
seed_headers = constraints.seed_headers
partitions.append(
Partition(
rows=RowPartition(
start=curr_start, end=curr_start + chunk_size
),
columns=ColumnPartition(
headers=list(df.columns), seed_headers=seed_headers
),
idx=partition_idx,
)
partition_idx += 1
)
partition_idx += 1
curr_start += chunk_size

return partitions
Expand All @@ -99,7 +86,6 @@ def _build_partitions(
class PartitionStrategy(BaseModel):
id: str
partitions: List[Partition]
header_cluster_count: int
original_headers: Optional[List[str]]
status_counter: Optional[dict]
_disk_location: Path = PrivateAttr(default=None)
Expand All @@ -112,7 +98,6 @@ def from_dataframe(
return cls(
id=id,
partitions=partitions,
header_cluster_count=constraints.header_cluster_count,
original_headers=list(df.columns),
status_counter=None,
)
Expand All @@ -135,7 +120,7 @@ def partition_count(self) -> int:

@property
def row_partition_count(self) -> int:
return math.ceil(len(self.partitions) / self.header_cluster_count)
return len(self.partitions)

def save_to(self, dest: Union[Path, str], overwrite: bool = False):
location = Path(dest)
Expand Down
13 changes: 0 additions & 13 deletions src/gretel_trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import pandas as pd
from gretel_client.config import get_session_config, RunnerMode
from gretel_client.projects import create_or_get_unique_project
from gretel_synthetics.utils.header_clusters import cluster

from gretel_trainer import runner, strategy
from gretel_trainer.models import _BaseConfig, determine_best_model
Expand Down Expand Up @@ -206,19 +205,7 @@ def _initialize_run(

model_config = self.model_type.config

header_clusters = cluster(
df,
maxsize=self.model_type.max_header_clusters,
header_prefix=seed_fields,
plot=False,
)
logger.info(
f"Header clustering created {len(header_clusters)} cluster(s) "
f"of length(s) {[len(x) for x in header_clusters]}"
)

constraints = strategy.PartitionConstraints(
header_clusters=header_clusters,
max_row_count=self.model_type.max_rows,
seed_headers=seed_fields,
)
Expand Down
Loading

0 comments on commit f5baf54

Please sign in to comment.