Skip to content

Commit

Permalink
Kill off all the old style types in Relational code (#110)
Browse files Browse the repository at this point in the history
  • Loading branch information
mikeknep authored May 18, 2023
1 parent 8ac8209 commit 7047e94
Show file tree
Hide file tree
Showing 23 changed files with 189 additions and 200 deletions.
14 changes: 7 additions & 7 deletions src/gretel_trainer/relational/ancestry.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from typing import Dict, List, Optional, Tuple
from typing import Optional

import pandas as pd

Expand All @@ -13,16 +13,16 @@

def get_multigenerational_primary_key(
rel_data: RelationalData, table: str
) -> List[str]:
) -> list[str]:
return [
f"{_START_LINEAGE}{_END_LINEAGE}{pk}" for pk in rel_data.get_primary_key(table)
]


def get_ancestral_foreign_key_maps(
rel_data: RelationalData, table: str
) -> List[Tuple[str, str]]:
def _ancestral_fk_map(fk: ForeignKey) -> List[Tuple[str, str]]:
) -> list[tuple[str, str]]:
def _ancestral_fk_map(fk: ForeignKey) -> list[tuple[str, str]]:
maps = []
fk_lineage = _COL_DELIMITER.join(fk.columns)

Expand All @@ -49,7 +49,7 @@ def _ancestral_fk_map(fk: ForeignKey) -> List[Tuple[str, str]]:
def get_table_data_with_ancestors(
rel_data: RelationalData,
table: str,
tableset: Optional[Dict[str, pd.DataFrame]] = None,
tableset: Optional[dict[str, pd.DataFrame]] = None,
ancestral_seeding: bool = False,
) -> pd.DataFrame:
"""
Expand All @@ -75,7 +75,7 @@ def _join_parents(
df: pd.DataFrame,
table: str,
lineage: str,
tableset: Optional[Dict[str, pd.DataFrame]],
tableset: Optional[dict[str, pd.DataFrame]],
ancestral_seeding: bool,
) -> pd.DataFrame:
for foreign_key in rel_data.get_foreign_keys(table):
Expand Down Expand Up @@ -134,7 +134,7 @@ def drop_ancestral_data(df: pd.DataFrame) -> pd.DataFrame:
return df[root_columns].rename(columns=mapper)


def prepend_foreign_key_lineage(df: pd.DataFrame, fk_cols: List[str]) -> pd.DataFrame:
def prepend_foreign_key_lineage(df: pd.DataFrame, fk_cols: list[str]) -> pd.DataFrame:
"""
Given a multigenerational dataframe, renames all columns such that the provided
foreign key columns act as the lineage from some child table to the provided data.
Expand Down
36 changes: 18 additions & 18 deletions src/gretel_trainer/relational/backup.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
from __future__ import annotations

from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional
from typing import Any, Optional

from gretel_trainer.relational.artifacts import ArtifactCollection
from gretel_trainer.relational.core import ForeignKey, RelationalData


@dataclass
class BackupRelationalDataTable:
primary_key: List[str]
primary_key: list[str]
invented_table_metadata: Optional[dict[str, str]] = None


@dataclass
class BackupForeignKey:
table: str
constrained_columns: List[str]
constrained_columns: list[str]
referred_table: str
referred_columns: List[str]
referred_columns: list[str]

@classmethod
def from_fk(cls, fk: ForeignKey) -> BackupForeignKey:
Expand All @@ -41,9 +41,9 @@ class BackupRelationalJson:

@dataclass
class BackupRelationalData:
tables: Dict[str, BackupRelationalDataTable]
foreign_keys: List[BackupForeignKey]
relational_jsons: Dict[str, BackupRelationalJson]
tables: dict[str, BackupRelationalDataTable]
foreign_keys: list[BackupForeignKey]
relational_jsons: dict[str, BackupRelationalJson]

@classmethod
def from_relational_data(cls, rel_data: RelationalData) -> BackupRelationalData:
Expand Down Expand Up @@ -80,30 +80,30 @@ def from_relational_data(cls, rel_data: RelationalData) -> BackupRelationalData:

@dataclass
class BackupClassify:
model_ids: Dict[str, str]
model_ids: dict[str, str]


@dataclass
class BackupTransformsTrain:
model_ids: Dict[str, str]
lost_contact: List[str]
model_ids: dict[str, str]
lost_contact: list[str]


@dataclass
class BackupSyntheticsTrain:
model_ids: Dict[str, str]
lost_contact: List[str]
training_columns: Dict[str, List[str]]
model_ids: dict[str, str]
lost_contact: list[str]
training_columns: dict[str, list[str]]


@dataclass
class BackupGenerate:
identifier: str
preserved: List[str]
preserved: list[str]
record_size_ratio: float
record_handler_ids: Dict[str, str]
lost_contact: List[str]
missing_model: List[str]
record_handler_ids: dict[str, str]
lost_contact: list[str]
missing_model: list[str]


@dataclass
Expand All @@ -125,7 +125,7 @@ def as_dict(self):
return asdict(self)

@classmethod
def from_dict(cls, b: Dict[str, Any]):
def from_dict(cls, b: dict[str, Any]):
relational_data = b["relational_data"]
brd = BackupRelationalData(
tables={
Expand Down
36 changes: 18 additions & 18 deletions src/gretel_trainer/relational/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass, replace
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional, Union

import networkx
import pandas as pd
Expand All @@ -25,18 +25,18 @@ class MultiTableException(Exception):
pass


GretelModelConfig = Union[str, Path, Dict]
GretelModelConfig = Union[str, Path, dict]


@dataclass
class ForeignKey:
table_name: str
columns: List[str]
columns: list[str]
parent_table_name: str
parent_columns: List[str]
parent_columns: list[str]


UserFriendlyPrimaryKeyT = Optional[Union[str, List[str]]]
UserFriendlyPrimaryKeyT = Optional[Union[str, list[str]]]


class Scope(str, Enum):
Expand Down Expand Up @@ -234,7 +234,7 @@ def _remove_relational_json(

return original_data, original_primary_key, original_foreign_keys

def _format_key_column(self, key: Optional[Union[str, List[str]]]) -> List[str]:
def _format_key_column(self, key: Optional[Union[str, list[str]]]) -> list[str]:
if key is None:
return []
elif isinstance(key, str):
Expand Down Expand Up @@ -265,9 +265,9 @@ def add_foreign_key_constraint(
self,
*,
table: str,
constrained_columns: List[str],
constrained_columns: list[str],
referred_table: str,
referred_columns: List[str],
referred_columns: list[str],
) -> None:
"""
Add a foreign key relationship between two tables.
Expand Down Expand Up @@ -342,7 +342,7 @@ def remove_foreign_key(self, foreign_key: str) -> None:
)

def remove_foreign_key_constraint(
self, table: str, constrained_columns: List[str]
self, table: str, constrained_columns: list[str]
) -> None:
"""
Remove an existing foreign key.
Expand Down Expand Up @@ -420,7 +420,7 @@ def update_table_data(self, table: str, data: pd.DataFrame) -> None:
metadata.columns = set(data.columns)
self._clear_safe_ancestral_seed_columns(table)

def list_all_tables(self, scope: Scope = Scope.MODELABLE) -> List[str]:
def list_all_tables(self, scope: Scope = Scope.MODELABLE) -> list[str]:
modelable_nodes = self.graph.nodes

json_source_tables = [
Expand Down Expand Up @@ -488,10 +488,10 @@ def get_invented_table_metadata(

return self.graph.nodes[table]["metadata"].invented_table_metadata

def get_parents(self, table: str) -> List[str]:
def get_parents(self, table: str) -> list[str]:
return list(self.graph.successors(table))

def get_ancestors(self, table: str) -> List[str]:
def get_ancestors(self, table: str) -> list[str]:
def _add_parents(ancestors, table):
parents = self.get_parents(table)
if len(parents) > 0:
Expand All @@ -504,7 +504,7 @@ def _add_parents(ancestors, table):

return list(ancestors)

def get_descendants(self, table: str) -> List[str]:
def get_descendants(self, table: str) -> list[str]:
def _add_children(descendants, table):
children = list(self.graph.predecessors(table))
if len(children) > 0:
Expand All @@ -517,7 +517,7 @@ def _add_children(descendants, table):

return list(descendants)

def list_tables_parents_before_children(self) -> List[str]:
def list_tables_parents_before_children(self) -> list[str]:
"""
Returns a list of all tables with the guarantee that a parent table
appears before any of its children. No other guarantees about order
Expand All @@ -526,7 +526,7 @@ def list_tables_parents_before_children(self) -> List[str]:
"""
return list(reversed(list(topological_sort(self.graph))))

def get_primary_key(self, table: str) -> List[str]:
def get_primary_key(self, table: str) -> list[str]:
try:
return self.graph.nodes[table]["metadata"].primary_key
except KeyError:
Expand Down Expand Up @@ -592,7 +592,7 @@ def _get_table_in_graph(self, table: str) -> str:

def get_foreign_keys(
self, table: str, rename_invented_tables: bool = False
) -> List[ForeignKey]:
) -> list[ForeignKey]:
def _rename_invented(fk: ForeignKey) -> ForeignKey:
table_name = fk.table_name
parent_table_name = fk.parent_table_name
Expand All @@ -619,14 +619,14 @@ def _rename_invented(fk: ForeignKey) -> ForeignKey:
else:
return foreign_keys

def get_all_key_columns(self, table: str) -> List[str]:
def get_all_key_columns(self, table: str) -> list[str]:
all_key_cols = []
all_key_cols.extend(self.get_primary_key(table))
for fk in self.get_foreign_keys(table):
all_key_cols.extend(fk.columns)
return all_key_cols

def debug_summary(self) -> Dict[str, Any]:
def debug_summary(self) -> dict[str, Any]:
max_depth = dag_longest_path_length(self.graph)
public_table_count = len(self.list_all_tables(Scope.PUBLIC))
invented_table_count = len(self.list_all_tables(Scope.INVENTED))
Expand Down
14 changes: 7 additions & 7 deletions src/gretel_trainer/relational/model_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from copy import deepcopy
from typing import Any, Dict, List
from typing import Any

from gretel_client.projects.models import read_model_config

Expand All @@ -10,7 +10,7 @@
)


def _ingest(config: GretelModelConfig) -> Dict[str, Any]:
def _ingest(config: GretelModelConfig) -> dict[str, Any]:
return read_model_config(deepcopy(config))


Expand All @@ -19,27 +19,27 @@ def _model_name(workflow: str, table: str) -> str:
return f"{workflow}-{ok_table_name}"


def make_classify_config(table: str, config: GretelModelConfig) -> Dict[str, Any]:
def make_classify_config(table: str, config: GretelModelConfig) -> dict[str, Any]:
tailored_config = _ingest(config)
tailored_config["name"] = _model_name("classify", table)
return tailored_config


def make_evaluate_config(table: str) -> Dict[str, Any]:
def make_evaluate_config(table: str) -> dict[str, Any]:
tailored_config = _ingest("evaluate/default")
tailored_config["name"] = _model_name("evaluate", table)
return tailored_config


def make_synthetics_config(table: str, config: GretelModelConfig) -> Dict[str, Any]:
def make_synthetics_config(table: str, config: GretelModelConfig) -> dict[str, Any]:
tailored_config = _ingest(config)
tailored_config["name"] = _model_name("synthetics", table)
return tailored_config


def make_transform_config(
rel_data: RelationalData, table: str, config: GretelModelConfig
) -> Dict[str, Any]:
) -> dict[str, Any]:
tailored_config = _ingest(config)
tailored_config["name"] = _model_name("transforms", table)

Expand All @@ -65,7 +65,7 @@ def make_transform_config(
return tailored_config


def _passthrough_policy(columns: List[str]) -> Dict[str, Any]:
def _passthrough_policy(columns: list[str]) -> dict[str, Any]:
return {
"name": "ignore-keys",
"rules": [
Expand Down
Loading

0 comments on commit 7047e94

Please sign in to comment.