Commit 18f5de1: Fix a couple SQL related errors

mckornfield committed Sep 12, 2023
1 parent 299a29e

src/gretel_trainer/relational/extractor.py: 18 additions & 13 deletions
@@ -3,20 +3,19 @@
 """
 from __future__ import annotations
 
+import itertools
 import logging
 
 from contextlib import nullcontext
 from dataclasses import asdict, dataclass
 from enum import Enum
 from pathlib import Path
 from threading import Lock
-from typing import Iterator, Optional, TYPE_CHECKING
+from typing import TYPE_CHECKING, Iterator, Optional
 
 import dask.dataframe as dd
 import numpy as np
 import pandas as pd
 
-from sqlalchemy import func, inspect, MetaData, select, Table, tuple_
+from sqlalchemy import MetaData, Table, and_, func, inspect, or_, select, tuple_
 
 from gretel_trainer.relational.core import RelationalData
@@ -380,7 +379,9 @@ def _load_table_pk_values(
         child_table_path = self._table_path(child_table_name)
 
         tmp_ddf = dd.read_csv(  # pyright: ignore
-            str(child_table_path), usecols=fk.columns
+            str(child_table_path),
+            usecols=fk.columns,
+            dtype={key: "object" for key in fk.columns},
         )
         tmp_ddf = tmp_ddf.rename(columns=rename_map)
         if values_ddf is None:
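Why the added dtype mapping matters: Dask infers column dtypes per partition, so a foreign-key column that looks numeric in one partition and string-like in another can raise a dtype mismatch, or lose leading zeros, before the keys are compared against the parent table. A minimal standalone sketch of the same pattern, with a hypothetical file path and column names:

    import dask.dataframe as dd

    fk_columns = ["user_id", "account_id"]  # hypothetical foreign-key columns

    # Pin the key columns to "object" so values such as "00123" stay strings
    # instead of each partition re-inferring its own dtype.
    ddf = dd.read_csv(
        "child_table.csv",  # hypothetical path
        usecols=fk_columns,
        dtype={key: "object" for key in fk_columns},
    )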
@@ -422,16 +423,21 @@ def handle_partition(df: pd.DataFrame, lock: Lock):
            table_session = self._get_table_session(pk_values.table_name)
            nonlocal row_count
 
-            chunk_size = 150  # limit how many checks go into a WHERE clause
+            chunk_size = 15_000  # limit how many checks go into a WHERE clause
 
            for _, chunk_df in df.groupby(np.arange(len(df)) // chunk_size):
                values_list = chunk_df.to_records(index=False).tolist()
-                query = table_session.table.select().where(
-                    tuple_(
-                        *[table_session.table.c[col] for col in pk_values.column_names]
-                    ).in_(values_list)
-                )
-
+                columns = [table_session.table.c[col] for col in pk_values.column_names]
+                # Produces a query that looks like
+                # SELECT * FROM TABLE WHERE (TABLE.A = 1 AND TABLE.B = 2) OR ...
+                column_comparisons = [
+                    and_(
+                        column == value
+                        for column, value in zip(itertools.cycle(columns), values)
+                    ).self_group()
+                    for values in values_list
+                ]
+                query = table_session.table.select().where(or_(*column_comparisons))
                with table_session.engine.connect() as conn:
                    df_iter = pd.read_sql_query(query, conn, chunksize=self._chunk_size)
                    write_count = _stream_df_to_path(df_iter, table_path, lock=lock)
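For context on the rewritten filter: the old `tuple_(...).in_(...)` form compiles to a row-value expression such as WHERE (a, b) IN ((1, 2), (3, 4)), which not every backend supports (SQL Server rejects it, for example), while the OR-of-ANDs form compiles to plain column comparisons everywhere. A standalone sketch of the same construction against a hypothetical table `events` with a two-column key and made-up key tuples:

    import itertools

    from sqlalchemy import Column, Integer, MetaData, Table, and_, or_, select

    # Hypothetical table with a composite two-column key
    table = Table(
        "events",
        MetaData(),
        Column("a", Integer),
        Column("b", Integer),
    )

    columns = [table.c["a"], table.c["b"]]
    values_list = [(1, 2), (3, 4)]  # hypothetical key tuples

    # (events.a = 1 AND events.b = 2) OR (events.a = 3 AND events.b = 4)
    comparisons = [
        and_(
            column == value
            for column, value in zip(itertools.cycle(columns), values)
        ).self_group()
        for values in values_list
    ]
    print(select(table).where(or_(*comparisons)))

Note that SQLAlchemy's and_() accepts a single generator argument, so the comprehension needs no * unpacking, and zipping against itertools.cycle(columns) pairs each key column with its value for keys of any width.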
@@ -542,7 +548,6 @@ def _sample_table(
        if self._config.entire_table:
            logger.debug(f"Extracting entire table: {table_name}")
            with engine.connect() as conn:
-                # TODO: Add a loading percentage here?
                df_iter = pd.read_sql_table(
                    table_name, conn, chunksize=self._chunk_size
                )
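The hunk above only drops a stale TODO, but the surrounding call is doing the real work: passing chunksize to pandas read_sql_table returns an iterator of DataFrames instead of materializing the whole table in memory. A minimal sketch, assuming a hypothetical SQLite database and table name:

    import pandas as pd
    from sqlalchemy import create_engine

    engine = create_engine("sqlite:///example.db")  # hypothetical database

    with engine.connect() as conn:
        # chunksize makes read_sql_table yield DataFrames lazily, so each
        # chunk can be streamed to disk without holding the full table.
        for chunk in pd.read_sql_table("events", conn, chunksize=10_000):
            print(len(chunk))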