Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support stream property selection push-down in SQL streams #1032

Merged
merged 9 commits into from
Oct 5, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 48 additions & 27 deletions singer_sdk/streams/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from sqlalchemy.engine import Engine
from sqlalchemy.engine.reflection import Inspector

import singer_sdk.helpers._catalog as catalog
from singer_sdk import typing as th
from singer_sdk._singerlib import CatalogEntry, MetadataMapping, Schema
from singer_sdk.exceptions import ConfigValidationError
Expand Down Expand Up @@ -325,11 +326,7 @@ def get_object_names(
# Some DB providers do not understand 'views'
self._warn_no_view_detection()
view_names = []
object_names = [(t, False) for t in table_names] + [
(v, True) for v in view_names
]

return object_names
return [(t, False) for t in table_names] + [(v, True) for v in view_names]

# TODO: maybe this should be split into smaller parts?
def discover_catalog_entry(
Expand Down Expand Up @@ -365,9 +362,13 @@ def discover_catalog_entry(
pk_def = inspected.get_pk_constraint(table_name, schema=schema_name)
if pk_def and "constrained_columns" in pk_def:
possible_primary_keys.append(pk_def["constrained_columns"])
for index_def in inspected.get_indexes(table_name, schema=schema_name):
if index_def.get("unique", False):
possible_primary_keys.append(index_def["column_names"])

possible_primary_keys.extend(
index_def["column_names"]
for index_def in inspected.get_indexes(table_name, schema=schema_name)
if index_def.get("unique", False)
)

key_properties = next(iter(possible_primary_keys), None)

# Initialize columns list
Expand Down Expand Up @@ -397,7 +398,7 @@ def discover_catalog_entry(
replication_method = next(reversed(["FULL_TABLE"] + addl_replication_methods))

# Create the catalog entry object
catalog_entry = CatalogEntry(
return CatalogEntry(
tap_stream_id=unique_stream_id,
stream=unique_stream_id,
table=table_name,
Expand All @@ -418,8 +419,6 @@ def discover_catalog_entry(
replication_key=None, # Must be defined by user
)

return catalog_entry

def discover_catalog_entries(self) -> list[dict]:
"""Return a list of catalog entries from discovery.

Expand Down Expand Up @@ -488,11 +487,14 @@ def table_exists(self, full_table_name: str) -> bool:
sqlalchemy.inspect(self._engine).has_table(full_table_name),
)

def get_table_columns(self, full_table_name: str) -> dict[str, sqlalchemy.Column]:
def get_table_columns(
self, full_table_name: str, column_names: list[str] | None = None
) -> dict[str, sqlalchemy.Column]:
"""Return a list of table columns.

Args:
full_table_name: Fully qualified table name.
column_names: A list of column names to filter to.

Returns:
A dictionary mapping each column name to its column object.
Expand All @@ -501,26 +503,32 @@ def get_table_columns(self, full_table_name: str) -> dict[str, sqlalchemy.Column
inspector = sqlalchemy.inspect(self._engine)
columns = inspector.get_columns(table_name, schema_name)

result: dict[str, sqlalchemy.Column] = {}
for col_meta in columns:
result[col_meta["name"]] = sqlalchemy.Column(
return {
col_meta["name"]: sqlalchemy.Column(
col_meta["name"],
col_meta["type"],
nullable=col_meta.get("nullable", False),
)

return result

def get_table(self, full_table_name: str) -> sqlalchemy.Table:
for col_meta in columns
if not column_names
or col_meta["name"].casefold() in {col.casefold() for col in column_names}
kgpayne marked this conversation as resolved.
Show resolved Hide resolved
}

def get_table(
self, full_table_name: str, column_names: list[str] | None = None
) -> sqlalchemy.Table:
"""Return a table object.

Args:
full_table_name: Fully qualified table name.
column_names: A list of column names to filter to.

Returns:
A table object with column list.
"""
columns = self.get_table_columns(full_table_name).values()
columns = self.get_table_columns(
full_table_name=full_table_name, column_names=column_names
).values()
_, schema_name, table_name = self.parse_full_table_name(full_table_name)
meta = sqlalchemy.MetaData()
return sqlalchemy.schema.Table(
Expand Down Expand Up @@ -910,11 +918,7 @@ def __init__(
connector: Optional connector to reuse.
"""
self._connector: SQLConnector
if connector:
self._connector = connector
else:
self._connector = self.connector_class(dict(tap.config))

self._connector = connector or self.connector_class(dict(tap.config))
self.catalog_entry = catalog_entry
super().__init__(
tap=tap,
Expand Down Expand Up @@ -1016,8 +1020,21 @@ def fully_qualified_name(self) -> str:
db_name=catalog_entry.database,
)

# Get records from stream
def get_selected_schema(self) -> dict:
    """Build the stream's JSON schema restricted to the selected fields.

    The work is delegated to the catalog helper, which receives this
    stream's name, schema, selection mask and logger.

    Returns:
        A dictionary holding a copy of the Stream JSON schema with any
        unselected fields dropped.
    """
    # Collect the arguments in the same order the helper is documented to
    # take them; each attribute is read exactly once.
    selection_kwargs = {
        "stream_name": self.name,
        "schema": self.schema,
        "mask": self.mask,
        "logger": self.logger,
    }
    return catalog.get_selected_schema(**selection_kwargs)

# Get records from stream
def get_records(self, context: dict | None) -> Iterable[dict[str, Any]]:
"""Return a generator of record-type dictionary objects.

Expand All @@ -1041,7 +1058,11 @@ def get_records(self, context: dict | None) -> Iterable[dict[str, Any]]:
f"Stream '{self.name}' does not support partitioning."
)

table = self.connector.get_table(self.fully_qualified_name)
selected_column_names = self.get_selected_schema()["properties"].keys()
table = self.connector.get_table(
full_table_name=self.fully_qualified_name,
column_names=selected_column_names,
)
query = table.select()
if self.replication_key:
replication_key_col = table.columns[self.replication_key]
Expand Down