Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

table extractor dask fixes #124

Merged
merged 2 commits into from
Jun 14, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions src/gretel_trainer/relational/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def _load_table_pk_values(
parent_column_names: list[str] = []
pk_set = set(self._relational_data.get_primary_key(table_name))
logger.debug(
f"Extacting primary key values for sampling from table '{table_name}'"
f"Extracting primary key values for sampling from table '{table_name}'"
)

for child_table_name in child_table_names:
Expand Down Expand Up @@ -380,8 +380,18 @@ def _load_table_pk_values(
else:
values_ddf = dd.concat([values_ddf, tmp_ddf]) # pyright: ignore

if parent_column_names:
values_ddf = values_ddf.drop_duplicates() # pyright: ignore
# Dropping the duplicates *only* works consistently
# when operating on a specific subset of columns using the []
# notation. Using the "subset=" kwarg does not work, and neither
# does operating on the entire DDF.
if parent_column_names and values_ddf is not None:
values_ddf = values_ddf[ # pyright: ignore
parent_column_names
].drop_duplicates() # pyright: ignore
else:
raise TableExtractorError(
f"Could not extract primary key values needed to sample from table `{table_name}`"
)

return _PKValues(
table_name=table_name,
Expand Down Expand Up @@ -422,7 +432,18 @@ def handle_partition(df: pd.DataFrame, lock: Lock):
logger.debug(
f"Sampling primary key values for parent table '{pk_values.table_name}'"
)
pk_values.values_ddf.map_partitions(handle_partition, lock).compute()

# Providing the "meta" kwarg prevents Dask from running the
# map function ("handle_partition") on dummy data in an attempt
# to infer the output metadata (which we don't need for the
# purposes of making the SQL queries). When that dummy partition
# was mapped, the values in it were used to make additional SQL
# queries, which can have unintended side effects. See the
# "map_partitions" docs for more details if interested.
pk_values.values_ddf.map_partitions(
handle_partition, lock, meta=(None, "object")
).compute()

return row_count

Expand Down