diff --git a/singer_sdk/streams/sql.py b/singer_sdk/streams/sql.py index 114c01674..6a0392485 100644 --- a/singer_sdk/streams/sql.py +++ b/singer_sdk/streams/sql.py @@ -12,6 +12,7 @@ from sqlalchemy.engine import Engine from sqlalchemy.engine.reflection import Inspector +import singer_sdk.helpers._catalog as catalog from singer_sdk import typing as th from singer_sdk._singerlib import CatalogEntry, MetadataMapping, Schema from singer_sdk.exceptions import ConfigValidationError @@ -325,11 +326,7 @@ def get_object_names( # Some DB providers do not understand 'views' self._warn_no_view_detection() view_names = [] - object_names = [(t, False) for t in table_names] + [ - (v, True) for v in view_names - ] - - return object_names + return [(t, False) for t in table_names] + [(v, True) for v in view_names] # TODO maybe should be splitted into smaller parts? def discover_catalog_entry( @@ -365,9 +362,13 @@ def discover_catalog_entry( pk_def = inspected.get_pk_constraint(table_name, schema=schema_name) if pk_def and "constrained_columns" in pk_def: possible_primary_keys.append(pk_def["constrained_columns"]) - for index_def in inspected.get_indexes(table_name, schema=schema_name): - if index_def.get("unique", False): - possible_primary_keys.append(index_def["column_names"]) + + possible_primary_keys.extend( + index_def["column_names"] + for index_def in inspected.get_indexes(table_name, schema=schema_name) + if index_def.get("unique", False) + ) + key_properties = next(iter(possible_primary_keys), None) # Initialize columns list @@ -397,7 +398,7 @@ def discover_catalog_entry( replication_method = next(reversed(["FULL_TABLE"] + addl_replication_methods)) # Create the catalog entry object - catalog_entry = CatalogEntry( + return CatalogEntry( tap_stream_id=unique_stream_id, stream=unique_stream_id, table=table_name, @@ -418,8 +419,6 @@ def discover_catalog_entry( replication_key=None, # Must be defined by user ) - return catalog_entry - def discover_catalog_entries(self) -> list[dict]: """Return a list of catalog entries from discovery. @@ -488,11 +487,14 @@ def table_exists(self, full_table_name: str) -> bool: sqlalchemy.inspect(self._engine).has_table(full_table_name), ) - def get_table_columns(self, full_table_name: str) -> dict[str, sqlalchemy.Column]: + def get_table_columns( + self, full_table_name: str, column_names: list[str] | None = None + ) -> dict[str, sqlalchemy.Column]: """Return a list of table columns. Args: full_table_name: Fully qualified table name. + column_names: A list of column names to filter to. Returns: An ordered list of column objects. @@ -501,26 +503,32 @@ def get_table_columns(self, full_table_name: str) -> dict[str, sqlalchemy.Column inspector = sqlalchemy.inspect(self._engine) columns = inspector.get_columns(table_name, schema_name) - result: dict[str, sqlalchemy.Column] = {} - for col_meta in columns: - result[col_meta["name"]] = sqlalchemy.Column( + return { + col_meta["name"]: sqlalchemy.Column( col_meta["name"], col_meta["type"], nullable=col_meta.get("nullable", False), ) - - return result - - def get_table(self, full_table_name: str) -> sqlalchemy.Table: + for col_meta in columns + if not column_names + or col_meta["name"].casefold() in {col.casefold() for col in column_names} + } + + def get_table( + self, full_table_name: str, column_names: list[str] | None = None + ) -> sqlalchemy.Table: """Return a table object. Args: full_table_name: Fully qualified table name. + column_names: A list of column names to filter to. Returns: A table object with column list. """ - columns = self.get_table_columns(full_table_name).values() + columns = self.get_table_columns( + full_table_name=full_table_name, column_names=column_names + ).values() _, schema_name, table_name = self.parse_full_table_name(full_table_name) meta = sqlalchemy.MetaData() return sqlalchemy.schema.Table( @@ -910,11 +918,7 @@ def __init__( connector: Optional connector to reuse. """ self._connector: SQLConnector - if connector: - self._connector = connector - else: - self._connector = self.connector_class(dict(tap.config)) - + self._connector = connector or self.connector_class(dict(tap.config)) self.catalog_entry = catalog_entry super().__init__( tap=tap, @@ -1016,8 +1020,21 @@ def fully_qualified_name(self) -> str: db_name=catalog_entry.database, ) - # Get records from stream + def get_selected_schema(self) -> dict: + """Return a copy of the Stream JSON schema, dropping any fields not selected. + Returns: + A dictionary containing a copy of the Stream JSON schema, filtered + to any selection criteria. + """ + return catalog.get_selected_schema( + stream_name=self.name, + schema=self.schema, + mask=self.mask, + logger=self.logger, + ) + + # Get records from stream def get_records(self, context: dict | None) -> Iterable[dict[str, Any]]: """Return a generator of record-type dictionary objects. @@ -1041,7 +1058,11 @@ def get_records(self, context: dict | None) -> Iterable[dict[str, Any]]: f"Stream '{self.name}' does not support partitioning." ) - table = self.connector.get_table(self.fully_qualified_name) + selected_column_names = self.get_selected_schema()["properties"].keys() + table = self.connector.get_table( + full_table_name=self.fully_qualified_name, + column_names=selected_column_names, + ) query = table.select() if self.replication_key: replication_key_col = table.columns[self.replication_key]