Skip to content

Commit

Permalink
Add session pass through
Browse files Browse the repository at this point in the history
  • Loading branch information
mckornfield committed Feb 6, 2024
1 parent cc4e406 commit ca33d73
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/gretel_trainer/relational/connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def extract(
extractor = TableExtractor(
config=config, connector=self, storage_dir=storage_dir_path
)
extractor.sample_tables()
extractor.sample_tables(schema=schema)

# We make sure to re-create RelationalData after extraction so
# we can account for any embedded JSON. This also loads
Expand Down
20 changes: 14 additions & 6 deletions src/gretel_trainer/relational/extractor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Extract database or data warehouse SQL tables to flat files with optional subsetting.
"""

from __future__ import annotations

import logging
Expand Down Expand Up @@ -242,9 +243,11 @@ def __init__(
self.table_order = []
self._chunk_size = 50_000

def _get_table_session(
    self, table_name: str, schema: Optional[str] = None
) -> _TableSession:
    """
    Reflect a single table from the connector's engine and wrap it in a
    `_TableSession`.

    Args:
        table_name: Name of the table to reflect.
        schema: Optional database schema containing the table. When `None`,
            no explicit schema is passed to reflection.

    Returns:
        A `_TableSession` holding the reflected table and the connector's engine.

    Raises:
        KeyError: If the requested table is not present in the reflected
            metadata.
    """
    metadata = MetaData()
    metadata.reflect(only=[table_name], bind=self._connector.engine, schema=schema)
    # SQLAlchemy stores reflected tables under `Table.key`: the bare table
    # name when there is no schema, and "<schema>.<name>" when one is given.
    # Looking up by the bare name alone would raise KeyError for any
    # schema-qualified table.
    table_key = table_name if schema is None else f"{schema}.{table_name}"
    table = metadata.tables[table_key]
    return _TableSession(table=table, engine=self._connector.engine)

Expand Down Expand Up @@ -515,13 +518,16 @@ def _flat_sample(
)

def _sample_table(
self, table_name: str, child_tables: Optional[list[str]] = None
self,
table_name: str,
child_tables: Optional[list[str]] = None,
schema: Optional[str] = None,
) -> TableMetadata:
if self._relational_data.is_empty:
self._extract_schema()

table_path = self._table_path(table_name)
table_session = self._get_table_session(table_name)
table_session = self._get_table_session(table_name, schema=schema)
engine = self._connector.engine

# First we'll create our table file on disk and bootstrap
Expand Down Expand Up @@ -573,7 +579,7 @@ def _sample_table(
column_count=table_session.total_column_count,
)

def sample_tables(self) -> dict[str, TableMetadata]:
def sample_tables(self, schema: Optional[str] = None) -> dict[str, TableMetadata]:
"""
Extract database tables according to the `ExtractorConfig`. Tables will be stored in the
storage directory that is configured on the `ExtractorConfig` object.
Expand All @@ -584,7 +590,9 @@ def sample_tables(self) -> dict[str, TableMetadata]:
table_data = {}
for table_name in self.table_order:
child_tables = self._relational_data.get_descendants(table_name)
meta = self._sample_table(table_name, child_tables=child_tables)
meta = self._sample_table(
table_name, child_tables=child_tables, schema=schema
)
table_data[table_name] = meta

return table_data
Expand Down

0 comments on commit ca33d73

Please sign in to comment.