diff --git a/src/gretel_trainer/relational/extractor.py b/src/gretel_trainer/relational/extractor.py
index a99ac90a..38c73fbb 100644
--- a/src/gretel_trainer/relational/extractor.py
+++ b/src/gretel_trainer/relational/extractor.py
@@ -349,7 +349,7 @@ def _load_table_pk_values(
         parent_column_names: list[str] = []
         pk_set = set(self._relational_data.get_primary_key(table_name))
         logger.debug(
-            f"Extacting primary key values for sampling from table '{table_name}'"
+            f"Extracting primary key values for sampling from table '{table_name}'"
         )
 
         for child_table_name in child_table_names:
@@ -380,8 +380,18 @@ def _load_table_pk_values(
             else:
                 values_ddf = dd.concat([values_ddf, tmp_ddf])  # pyright: ignore
 
-        if parent_column_names:
-            values_ddf = values_ddf.drop_duplicates()  # pyright: ignore
+        # Dropping the duplicates *only* works consistently
+        # when operating on a specific subset of columns using the []
+        # notation. Using the "subset=" kwarg does not work, and neither
+        # does operating on the entire DDF.
+        if parent_column_names and values_ddf is not None:
+            values_ddf = values_ddf[  # pyright: ignore
+                parent_column_names
+            ].drop_duplicates()  # pyright: ignore
+        else:
+            raise TableExtractorError(
+                f"Could not extract primary key values needed to sample from table `{table_name}`"
+            )
 
         return _PKValues(
             table_name=table_name,
@@ -422,7 +432,18 @@ def handle_partition(df: pd.DataFrame, lock: Lock):
 
         logger.debug(
             f"Sampling primary key values for parent table '{pk_values.table_name}'"
         )
-        pk_values.values_ddf.map_partitions(handle_partition, lock).compute()
+
+        # Providing the "meta" kwarg prevents Dask from running the
+        # mapped function ("handle_partition") on a dummy partition in
+        # an attempt to infer the output metadata (which we don't need
+        # for the purposes of making the SQL queries). When that dummy
+        # partition was mapped, the function used its placeholder values
+        # to issue additional SQL queries, which can have unintended
+        # side effects. See the "map_partitions" docs for more details
+        # if interested.
+        pk_values.values_ddf.map_partitions(
+            handle_partition, lock, meta=(None, "object")
+        ).compute()
 
         return row_count
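
For context on the dedup change above, here is a minimal sketch of the pattern the patch adopts: selecting the key columns with `[]` before calling `drop_duplicates()` on a Dask DataFrame. The table contents and column names below are made up for illustration, not taken from the repository.

```python
import dask.dataframe as dd
import pandas as pd

# Hypothetical stand-in for the extracted primary/foreign key values.
pdf = pd.DataFrame(
    {"parent_id": [1, 1, 2, 2, 3], "other": ["a", "b", "c", "c", "e"]}
)
values_ddf = dd.from_pandas(pdf, npartitions=2)

parent_column_names = ["parent_id"]

# Select the subset of key columns first, then deduplicate the resulting
# frame, mirroring values_ddf[parent_column_names].drop_duplicates() above.
deduped = values_ddf[parent_column_names].drop_duplicates()
print(deduped.compute())
```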
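The `meta=` change can be illustrated with a small sketch, assuming a side-effecting partition handler similar in spirit to `handle_partition` (the function body and data below are invented): without `meta`, Dask calls the function once on a placeholder partition to infer the output schema, which would trigger the side effect an extra time.

```python
import dask.dataframe as dd
import pandas as pd

pdf = pd.DataFrame({"parent_id": [1, 2, 3, 4]})
values_ddf = dd.from_pandas(pdf, npartitions=2)


def handle_partition(df: pd.DataFrame) -> pd.Series:
    # Stand-in for the real handler, which issues SQL queries based on the
    # partition's values; printing makes any extra inference call visible.
    print(f"handling partition with {len(df)} rows")
    return pd.Series(["ok"] * len(df), index=df.index, dtype="object")


# Without meta=, Dask first runs handle_partition on a dummy partition to
# infer the result's schema. Declaring the result as an unnamed object
# Series via meta=(None, "object") skips that extra call.
values_ddf.map_partitions(handle_partition, meta=(None, "object")).compute()
```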