Skip to content

Commit

Permalink
table extractor dask fixes (#124)
Browse files Browse the repository at this point in the history
* fixes
  • Loading branch information
johntmyers authored Jun 14, 2023
1 parent 6b450d6 commit 0888bc8
Showing 1 changed file with 25 additions and 4 deletions.
29 changes: 25 additions & 4 deletions src/gretel_trainer/relational/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def _load_table_pk_values(
parent_column_names: list[str] = []
pk_set = set(self._relational_data.get_primary_key(table_name))
logger.debug(
f"Extacting primary key values for sampling from table '{table_name}'"
f"Extracting primary key values for sampling from table '{table_name}'"
)

for child_table_name in child_table_names:
Expand Down Expand Up @@ -380,8 +380,18 @@ def _load_table_pk_values(
else:
values_ddf = dd.concat([values_ddf, tmp_ddf]) # pyright: ignore

if parent_column_names:
values_ddf = values_ddf.drop_duplicates() # pyright: ignore
# Dropping the duplicates *only* works consistently
# when operating on a specific subset of columns using the []
# notation. Using the "subset=" kwarg does not work, and neither
# does operating on the entire DDF.
if parent_column_names and values_ddf is not None:
values_ddf = values_ddf[ # pyright: ignore
parent_column_names
].drop_duplicates() # pyright: ignore
else:
raise TableExtractorError(
f"Could not extract primary key values needed to sample from table `{table_name}`"
)

return _PKValues(
table_name=table_name,
Expand Down Expand Up @@ -422,7 +432,18 @@ def handle_partition(df: pd.DataFrame, lock: Lock):
logger.debug(
f"Sampling primary key values for parent table '{pk_values.table_name}'"
)
pk_values.values_ddf.map_partitions(handle_partition, lock).compute()

# By providing the "meta" kwarg, this prevents
# Dask from running the map function ("handle_partition") on
# dummy data in an attempt to infer the metadata (which we don't
# need for the purposes of making the SQL queries). When this
# dummy partition is mapped, it was using the values in the
# partition to make additional SQL queries which can have
# unintended side effects. See the "map_partitions" docs
# for more details if interested.
pk_values.values_ddf.map_partitions(
handle_partition, lock, meta=(None, "object")
).compute()

return row_count

Expand Down

0 comments on commit 0888bc8

Please sign in to comment.