From 20cd58419e8b1bc031f6ff2de53364dce44904c8 Mon Sep 17 00:00:00 2001
From: Mayuri N
Date: Mon, 25 Jul 2022 19:35:59 +0530
Subject: [PATCH] feat(ingest): add snowflake-beta source

---
 .../docs/sources/snowflake/README.md          |   5 +-
 .../docs/sources/snowflake/snowflake-beta.md  |  57 ++
 .../snowflake/snowflake-beta_recipe.yml       |  46 ++
 metadata-ingestion/setup.py                   |   2 +
 .../ingestion/source/snowflake/__init__.py    |   1 +
 .../source/snowflake/snowflake_lineage.py     | 434 +++++++++++
 .../source/snowflake/snowflake_schema.py      | 468 ++++++++++++
 .../source/snowflake/snowflake_v2.py          | 700 ++++++++++++++++++
 .../datahub/ingestion/source/sql/snowflake.py | 417 +----------
 9 files changed, 1732 insertions(+), 398 deletions(-)
 create mode 100644 metadata-ingestion/docs/sources/snowflake/snowflake-beta.md
 create mode 100644 metadata-ingestion/docs/sources/snowflake/snowflake-beta_recipe.yml
 create mode 100644 metadata-ingestion/src/datahub/ingestion/source/snowflake/__init__.py
 create mode 100644 metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage.py
 create mode 100644 metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py
 create mode 100644 metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py

diff --git a/metadata-ingestion/docs/sources/snowflake/README.md b/metadata-ingestion/docs/sources/snowflake/README.md
index 62c6c155135570..c48e9329d7d61b 100644
--- a/metadata-ingestion/docs/sources/snowflake/README.md
+++ b/metadata-ingestion/docs/sources/snowflake/README.md
@@ -1 +1,4 @@
-To get all metadata from Snowflake you need to use two plugins `snowflake` and `snowflake-usage`. Both of them are described in this page. These will require 2 separate recipes. We understand this is not ideal and we plan to make this easier in the future.
+To get all metadata from Snowflake you need to use the plugins `snowflake` and `snowflake-usage`. Both of them are described on this page. They require two separate recipes. We understand this is not ideal and we plan to make this easier in the future.
+
+
+We encourage you to try out the new `snowflake-beta` plugin as an alternative to the `snowflake` plugin and share your feedback. `snowflake-beta` is much faster than `snowflake` at extracting metadata. Please note that the `snowflake-beta` plugin currently does not support stateful ingestion and column-level profiling, unlike the `snowflake` plugin.
\ No newline at end of file
diff --git a/metadata-ingestion/docs/sources/snowflake/snowflake-beta.md b/metadata-ingestion/docs/sources/snowflake/snowflake-beta.md
new file mode 100644
index 00000000000000..d04797488e1c13
--- /dev/null
+++ b/metadata-ingestion/docs/sources/snowflake/snowflake-beta.md
@@ -0,0 +1,57 @@
+### Prerequisites
+
+In order to execute this source, your Snowflake user will need to have specific privileges granted to it for reading metadata
+from your warehouse.
+
+A Snowflake system admin can follow this guide to create a DataHub-specific role, assign it the required privileges, and assign it to a new DataHub user by executing the following Snowflake commands from a user with the `ACCOUNTADMIN` role or `MANAGE GRANTS` privilege.
+
+```sql
+create or replace role datahub_role;
+
+// Grant access to a warehouse to run queries to view metadata
+grant operate, usage on warehouse "<your-warehouse>" to role datahub_role;
+
+// Grant access to view database and schema in which your tables/views exist
+grant usage on DATABASE "<your-database>" to role datahub_role;
+grant usage on all schemas in database "<your-database>" to role datahub_role;
+grant usage on future schemas in database "<your-database>" to role datahub_role;
+
+// If you are NOT using Snowflake Profiling feature: Grant references privileges to your tables and views
+grant references on all tables in database "<your-database>" to role datahub_role;
+grant references on future tables in database "<your-database>" to role datahub_role;
+grant references on all external tables in database "<your-database>" to role datahub_role;
+grant references on future external tables in database "<your-database>" to role datahub_role;
+grant references on all views in database "<your-database>" to role datahub_role;
+grant references on future views in database "<your-database>" to role datahub_role;
+
+// If you ARE using Snowflake Profiling feature: Grant select privileges to your tables and views
+grant select on all tables in database "<your-database>" to role datahub_role;
+grant select on future tables in database "<your-database>" to role datahub_role;
+grant select on all external tables in database "<your-database>" to role datahub_role;
+grant select on future external tables in database "<your-database>" to role datahub_role;
+grant select on all views in database "<your-database>" to role datahub_role;
+grant select on future views in database "<your-database>" to role datahub_role;
+
+// Create a new DataHub user and assign the DataHub role to it
+create user datahub_user display_name = 'DataHub' password='<your-password>' default_role = datahub_role default_warehouse = '<your-warehouse>';
+
+// Grant the datahub_role to the new DataHub user.
+grant role datahub_role to user datahub_user;
+```
+
+The details of each granted privilege can be viewed in the [Snowflake docs](https://docs.snowflake.com/en/user-guide/security-access-control-privileges.html). A summary of each privilege and why it is required for this connector:
+- `operate` is required on the warehouse to execute queries
+- `usage` is required for us to run queries using the warehouse
+- `usage` on `database` and `schema` is required because without it, the tables and views inside them are not accessible. If an admin grants the required privileges on a `table` but misses the grants on the `schema` or the `database` in which the table/view exists, then we will not be able to get metadata for that table/view.
+- If metadata is required only for some schemas, you can grant the usage privileges on a particular schema instead, like
+```sql
+grant usage on schema "<your-database>"."<your-schema>" to role datahub_role;
+```
+- To get lineage and usage data, we need access to the default `snowflake` database
+
+This represents the bare minimum privileges required to extract databases, schemas, views, and tables from Snowflake.
+
+If you plan to enable extraction of table lineage via the `include_table_lineage` config flag, you'll also need to grant access to the [Account Usage](https://docs.snowflake.com/en/sql-reference/account-usage.html) system tables, from which the DataHub source extracts lineage information.
+```sql
+grant imported privileges on database snowflake to role datahub_role;
+```
\ No newline at end of file
diff --git a/metadata-ingestion/docs/sources/snowflake/snowflake-beta_recipe.yml b/metadata-ingestion/docs/sources/snowflake/snowflake-beta_recipe.yml
new file mode 100644
index 00000000000000..3f310447eb33c0
--- /dev/null
+++ b/metadata-ingestion/docs/sources/snowflake/snowflake-beta_recipe.yml
@@ -0,0 +1,46 @@
+source:
+  type: snowflake-beta
+  config:
+
+    # This option is recommended for the first run, to ingest all historical lineage
+    ignore_start_time_lineage: true
+    # This is an alternative option to specify the start_time for lineage
+    # if you don't want to look back to the beginning
+    start_time: '2022-03-01T00:00:00Z'
+
+    # Coordinates
+    account_id: "abc48144"
+    warehouse: "COMPUTE_WH"
+
+    # Credentials
+    username: "${SNOWFLAKE_USER}"
+    password: "${SNOWFLAKE_PASS}"
+    role: "datahub_role"
+
+    # Change these as per your database names. Remove to allow all databases
+    database_pattern:
+      allow:
+        - "^ACCOUNTING_DB$"
+        - "^MARKETING_DB$"
+    schema_pattern:
+      deny:
+        - "information_schema.*"
+    table_pattern:
+      allow:
+        # If you want to ingest only a few tables with names ending in revenue and sales
+        - ".*revenue"
+        - ".*sales"
+
+    profiling:
+      # Change to false to disable profiling
+      enabled: true
+      profile_table_level_only: true
+      profile_pattern:
+        allow:
+          - 'ACCOUNTING_DB.*.*'
+          - 'MARKETING_DB.*.*'
+        deny:
+          - '.*information_schema.*'
+
+sink:
+# sink configs
\ No newline at end of file
diff --git a/metadata-ingestion/setup.py b/metadata-ingestion/setup.py
index 89cf4e9ee39c98..1647f31612f22c 100644
--- a/metadata-ingestion/setup.py
+++ b/metadata-ingestion/setup.py
@@ -302,6 +302,7 @@ def get_long_description():
     | {
         "more-itertools>=8.12.0",
     },
+    "snowflake-beta": snowflake_common,
     "sqlalchemy": sql_common,
     "superset": {
         "requests",
@@ -533,6 +534,7 @@ def get_long_description():
        "redshift-usage = datahub.ingestion.source.usage.redshift_usage:RedshiftUsageSource",
        "snowflake = datahub.ingestion.source.sql.snowflake:SnowflakeSource",
        "snowflake-usage = datahub.ingestion.source.usage.snowflake_usage:SnowflakeUsageSource",
+        "snowflake-beta = datahub.ingestion.source.snowflake.snowflake_v2:SnowflakeV2Source",
        "superset = datahub.ingestion.source.superset:SupersetSource",
        "tableau = datahub.ingestion.source.tableau:TableauSource",
        "openapi = datahub.ingestion.source.openapi:OpenApiSource",
diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/__init__.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/__init__.py
new file mode 100644
index 00000000000000..3fe7e5a34277e1
--- /dev/null
+++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/__init__.py
@@ -0,0 +1 @@
+from datahub.ingestion.source.s3.source import S3Source
diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage.py
new file mode 100644
index 00000000000000..67144f5f414de2
--- /dev/null
+++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage.py
@@ -0,0 +1,434 @@
+import json
+import logging
+from collections import defaultdict
+from typing import Dict, List, Optional, Set, Tuple
+
+import sqlalchemy
+from sqlalchemy import create_engine
+
+import datahub.emitter.mce_builder as builder
+from datahub.ingestion.source.aws.s3_util import make_s3_urn
+from datahub.ingestion.source_config.sql.snowflake import SnowflakeConfig
+from
datahub.ingestion.source_report.sql.snowflake import SnowflakeReport +from datahub.metadata.com.linkedin.pegasus2avro.dataset import UpstreamLineage +from datahub.metadata.schema_classes import DatasetLineageTypeClass, UpstreamClass + +logger: logging.Logger = logging.getLogger(__name__) + + +class SnowflakeLineageExtractor: + def __init__(self, config: SnowflakeConfig, report: SnowflakeReport) -> None: + self._lineage_map: Optional[Dict[str, List[Tuple[str, str, str]]]] = None + self._external_lineage_map: Optional[Dict[str, Set[str]]] = None + self.config = config + self.platform = "snowflake" + self.report = report + + def _get_upstream_lineage_info( + self, dataset_name: str + ) -> Optional[Tuple[UpstreamLineage, Dict[str, str]]]: + + if not self.config.include_table_lineage: + return None + + if self._lineage_map is None: + engine = self.get_metadata_engine_for_lineage() + self._populate_lineage(engine) + self._populate_view_lineage(engine) + if self._external_lineage_map is None: + engine = self.get_metadata_engine_for_lineage() + self._populate_external_lineage(engine) + + assert self._lineage_map is not None + assert self._external_lineage_map is not None + + lineage = self._lineage_map[dataset_name] + external_lineage = self._external_lineage_map[dataset_name] + if not (lineage or external_lineage): + logger.debug(f"No lineage found for {dataset_name}") + return None + upstream_tables: List[UpstreamClass] = [] + column_lineage: Dict[str, str] = {} + for lineage_entry in lineage: + # Update the table-lineage + upstream_table_name = lineage_entry[0] + if not self._is_dataset_allowed(upstream_table_name): + continue + upstream_table = UpstreamClass( + dataset=builder.make_dataset_urn_with_platform_instance( + self.platform, + upstream_table_name, + self.config.platform_instance, + self.config.env, + ), + type=DatasetLineageTypeClass.TRANSFORMED, + ) + upstream_tables.append(upstream_table) + # Update column-lineage for each down-stream column. 
+ upstream_columns = [ + d["columnName"].lower() for d in json.loads(lineage_entry[1]) + ] + downstream_columns = [ + d["columnName"].lower() for d in json.loads(lineage_entry[2]) + ] + upstream_column_str = ( + f"{upstream_table_name}({', '.join(sorted(upstream_columns))})" + ) + downstream_column_str = ( + f"{dataset_name}({', '.join(sorted(downstream_columns))})" + ) + column_lineage_key = f"column_lineage[{upstream_table_name}]" + column_lineage_value = ( + f"{{{upstream_column_str} -> {downstream_column_str}}}" + ) + column_lineage[column_lineage_key] = column_lineage_value + logger.debug(f"{column_lineage_key}:{column_lineage_value}") + + for external_lineage_entry in external_lineage: + # For now, populate only for S3 + if external_lineage_entry.startswith("s3://"): + external_upstream_table = UpstreamClass( + dataset=make_s3_urn(external_lineage_entry, self.config.env), + type=DatasetLineageTypeClass.COPY, + ) + upstream_tables.append(external_upstream_table) + + if upstream_tables: + logger.debug( + f"Upstream lineage of '{dataset_name}': {[u.dataset for u in upstream_tables]}" + ) + if self.config.upstream_lineage_in_report: + self.report.upstream_lineage[dataset_name] = [ + u.dataset for u in upstream_tables + ] + return UpstreamLineage(upstreams=upstream_tables), column_lineage + return None + + def _populate_view_lineage(self, engine: sqlalchemy.engine.Engine) -> None: + if not self.config.include_view_lineage: + return + self._populate_view_upstream_lineage(engine) + self._populate_view_downstream_lineage(engine) + + def _populate_external_lineage(self, engine: sqlalchemy.engine.Engine) -> None: + # Handles the case where a table is populated from an external location via copy. + # Eg: copy into category_english from 's3://acryl-snow-demo-olist/olist_raw_data/category_english'credentials=(aws_key_id='...' 
aws_secret_key='...') pattern='.*.csv'; + query: str = """ + WITH external_table_lineage_history AS ( + SELECT + r.value:"locations" as upstream_locations, + w.value:"objectName" AS downstream_table_name, + w.value:"objectDomain" AS downstream_table_domain, + w.value:"columns" AS downstream_table_columns, + t.query_start_time AS query_start_time + FROM + (SELECT * from snowflake.account_usage.access_history) t, + lateral flatten(input => t.BASE_OBJECTS_ACCESSED) r, + lateral flatten(input => t.OBJECTS_MODIFIED) w + WHERE r.value:"locations" IS NOT NULL + AND w.value:"objectId" IS NOT NULL + AND t.query_start_time >= to_timestamp_ltz({start_time_millis}, 3) + AND t.query_start_time < to_timestamp_ltz({end_time_millis}, 3)) + SELECT upstream_locations, downstream_table_name, downstream_table_columns + FROM external_table_lineage_history + WHERE downstream_table_domain = 'Table' + QUALIFY ROW_NUMBER() OVER (PARTITION BY downstream_table_name ORDER BY query_start_time DESC) = 1""".format( + start_time_millis=int(self.config.start_time.timestamp() * 1000) + if not self.config.ignore_start_time_lineage + else 0, + end_time_millis=int(self.config.end_time.timestamp() * 1000), + ) + + num_edges: int = 0 + self._external_lineage_map = defaultdict(set) + try: + for db_row in engine.execute(query): + # key is the down-stream table name + key: str = db_row[1].lower().replace('"', "") + if not self._is_dataset_allowed(key): + continue + self._external_lineage_map[key] |= {*json.loads(db_row[0])} + logger.debug( + f"ExternalLineage[Table(Down)={key}]:External(Up)={self._external_lineage_map[key]} via access_history" + ) + except Exception as e: + logger.warning( + f"Populating table external lineage from Snowflake failed." + f"Please check your premissions. Continuing...\nError was {e}." + ) + # Handles the case for explicitly created external tables. + # NOTE: Snowflake does not log this information to the access_history table. + external_tables_query: str = "show external tables in account" + try: + for db_row in engine.execute(external_tables_query): + key = ( + f"{db_row.database_name}.{db_row.schema_name}.{db_row.name}".lower() + ) + if not self._is_dataset_allowed(dataset_name=key): + continue + self._external_lineage_map[key].add(db_row.location) + logger.debug( + f"ExternalLineage[Table(Down)={key}]:External(Up)={self._external_lineage_map[key]} via show external tables" + ) + num_edges += 1 + except Exception as e: + self.warn( + logger, + "external_lineage", + f"Populating external table lineage from Snowflake failed." + f"Please check your premissions. 
Continuing...\nError was {e}.", + ) + logger.info(f"Found {num_edges} external lineage edges.") + self.report.num_external_table_edges_scanned = num_edges + + def _populate_lineage(self, engine: sqlalchemy.engine.Engine) -> None: + query: str = """ +WITH table_lineage_history AS ( + SELECT + r.value:"objectName" AS upstream_table_name, + r.value:"objectDomain" AS upstream_table_domain, + r.value:"columns" AS upstream_table_columns, + w.value:"objectName" AS downstream_table_name, + w.value:"objectDomain" AS downstream_table_domain, + w.value:"columns" AS downstream_table_columns, + t.query_start_time AS query_start_time + FROM + (SELECT * from snowflake.account_usage.access_history) t, + lateral flatten(input => t.DIRECT_OBJECTS_ACCESSED) r, + lateral flatten(input => t.OBJECTS_MODIFIED) w + WHERE r.value:"objectId" IS NOT NULL + AND w.value:"objectId" IS NOT NULL + AND w.value:"objectName" NOT LIKE '%.GE_TMP_%' + AND w.value:"objectName" NOT LIKE '%.GE_TEMP_%' + AND t.query_start_time >= to_timestamp_ltz({start_time_millis}, 3) + AND t.query_start_time < to_timestamp_ltz({end_time_millis}, 3)) +SELECT upstream_table_name, downstream_table_name, upstream_table_columns, downstream_table_columns +FROM table_lineage_history +WHERE upstream_table_domain in ('Table', 'External table') and downstream_table_domain = 'Table' +QUALIFY ROW_NUMBER() OVER (PARTITION BY downstream_table_name, upstream_table_name ORDER BY query_start_time DESC) = 1 """.format( + start_time_millis=int(self.config.start_time.timestamp() * 1000) + if not self.config.ignore_start_time_lineage + else 0, + end_time_millis=int(self.config.end_time.timestamp() * 1000), + ) + num_edges: int = 0 + self._lineage_map = defaultdict(list) + try: + for db_row in engine.execute(query): + # key is the down-stream table name + key: str = db_row[1].lower().replace('"', "") + upstream_table_name = db_row[0].lower().replace('"', "") + if not ( + self._is_dataset_allowed(key) + or self._is_dataset_allowed(upstream_table_name) + ): + continue + self._lineage_map[key].append( + # (, , ) + (upstream_table_name, db_row[2], db_row[3]) + ) + num_edges += 1 + logger.debug( + f"Lineage[Table(Down)={key}]:Table(Up)={self._lineage_map[key]}" + ) + except Exception as e: + self.warn( + logger, + "lineage", + f"Extracting lineage from Snowflake failed." + f"Please check your premissions. Continuing...\nError was {e}.", + ) + logger.info( + f"A total of {num_edges} Table->Table edges found" + f" for {len(self._lineage_map)} downstream tables.", + ) + self.report.num_table_to_table_edges_scanned = num_edges + + def _is_dataset_allowed( + self, dataset_name: Optional[str], is_view: bool = False + ) -> bool: + # View lineages is not supported. Add the allow/deny pattern for that when it is supported. 
+ if dataset_name is None: + return True + dataset_params = dataset_name.split(".") + if len(dataset_params) != 3: + return True + if ( + not self.config.database_pattern.allowed(dataset_params[0]) + or not self.config.schema_pattern.allowed(dataset_params[1]) + or ( + not is_view and not self.config.table_pattern.allowed(dataset_params[2]) + ) + or (is_view and not self.config.view_pattern.allowed(dataset_params[2])) + ): + return False + return True + + def warn(self, log: logging.Logger, key: str, reason: str) -> None: + self.report.report_warning(key, reason) + log.warning(f"{key} => {reason}") + + def _populate_view_upstream_lineage(self, engine: sqlalchemy.engine.Engine) -> None: + # NOTE: This query captures only the upstream lineage of a view (with no column lineage). + # For more details see: https://docs.snowflake.com/en/user-guide/object-dependencies.html#object-dependencies + # and also https://docs.snowflake.com/en/sql-reference/account-usage/access_history.html#usage-notes for current limitations on capturing the lineage for views. + view_upstream_lineage_query: str = """ +SELECT + concat( + referenced_database, '.', referenced_schema, + '.', referenced_object_name + ) AS view_upstream, + concat( + referencing_database, '.', referencing_schema, + '.', referencing_object_name + ) AS downstream_view +FROM + snowflake.account_usage.object_dependencies +WHERE + referencing_object_domain in ('VIEW', 'MATERIALIZED VIEW') + """ + + assert self._lineage_map is not None + num_edges: int = 0 + + try: + for db_row in engine.execute(view_upstream_lineage_query): + # Process UpstreamTable/View/ExternalTable/Materialized View->View edge. + view_upstream: str = db_row["view_upstream"].lower() + view_name: str = db_row["downstream_view"].lower() + if not self._is_dataset_allowed(dataset_name=view_name, is_view=True): + continue + # key is the downstream view name + self._lineage_map[view_name].append( + # (, , ) + (view_upstream, "[]", "[]") + ) + num_edges += 1 + logger.debug( + f"Upstream->View: Lineage[View(Down)={view_name}]:Upstream={view_upstream}" + ) + except Exception as e: + self.warn( + logger, + "view_upstream_lineage", + "Extracting the upstream view lineage from Snowflake failed." + + f"Please check your permissions. Continuing...\nError was {e}.", + ) + logger.info(f"A total of {num_edges} View upstream edges found.") + self.report.num_table_to_view_edges_scanned = num_edges + + def _populate_view_downstream_lineage( + self, engine: sqlalchemy.engine.Engine + ) -> None: + # This query captures the downstream table lineage for views. + # See https://docs.snowflake.com/en/sql-reference/account-usage/access_history.html#usage-notes for current limitations on capturing the lineage for views. + # Eg: For viewA->viewB->ViewC->TableD, snowflake does not yet log intermediate view logs, resulting in only the viewA->TableD edge. 
+ view_lineage_query: str = """ +WITH view_lineage_history AS ( + SELECT + vu.value : "objectName" AS view_name, + vu.value : "objectDomain" AS view_domain, + vu.value : "columns" AS view_columns, + w.value : "objectName" AS downstream_table_name, + w.value : "objectDomain" AS downstream_table_domain, + w.value : "columns" AS downstream_table_columns, + t.query_start_time AS query_start_time + FROM + ( + SELECT + * + FROM + snowflake.account_usage.access_history + ) t, + lateral flatten(input => t.DIRECT_OBJECTS_ACCESSED) vu, + lateral flatten(input => t.OBJECTS_MODIFIED) w + WHERE + vu.value : "objectId" IS NOT NULL + AND w.value : "objectId" IS NOT NULL + AND w.value : "objectName" NOT LIKE '%.GE_TMP_%' + AND w.value : "objectName" NOT LIKE '%.GE_TEMP_%' + AND t.query_start_time >= to_timestamp_ltz({start_time_millis}, 3) + AND t.query_start_time < to_timestamp_ltz({end_time_millis}, 3) +) +SELECT + view_name, + view_columns, + downstream_table_name, + downstream_table_columns +FROM + view_lineage_history +WHERE + view_domain in ('View', 'Materialized view') + QUALIFY ROW_NUMBER() OVER ( + PARTITION BY view_name, + downstream_table_name + ORDER BY + query_start_time DESC + ) = 1 + """.format( + start_time_millis=int(self.config.start_time.timestamp() * 1000) + if not self.config.ignore_start_time_lineage + else 0, + end_time_millis=int(self.config.end_time.timestamp() * 1000), + ) + + assert self._lineage_map is not None + self.report.num_view_to_table_edges_scanned = 0 + + try: + db_rows = engine.execute(view_lineage_query) + except Exception as e: + self.warn( + logger, + "view_downstream_lineage", + f"Extracting the view lineage from Snowflake failed." + f"Please check your permissions. Continuing...\nError was {e}.", + ) + else: + for db_row in db_rows: + view_name: str = db_row["view_name"].lower().replace('"', "") + if not self._is_dataset_allowed(dataset_name=view_name, is_view=True): + continue + downstream_table: str = ( + db_row["downstream_table_name"].lower().replace('"', "") + ) + # Capture view->downstream table lineage. + self._lineage_map[downstream_table].append( + # (, , ) + ( + view_name, + db_row["view_columns"], + db_row["downstream_table_columns"], + ) + ) + self.report.num_view_to_table_edges_scanned += 1 + + logger.debug( + f"View->Table: Lineage[Table(Down)={downstream_table}]:View(Up)={self._lineage_map[downstream_table]}" + ) + + logger.info( + f"Found {self.report.num_view_to_table_edges_scanned} View->Table edges." 
+ ) + + def get_metadata_engine_for_lineage(self) -> sqlalchemy.engine.Engine: + + username = self.config.username + password = self.config.password + role = self.config.role + + url = self.config.get_sql_alchemy_url( + database=None, username=username, password=password, role=role + ) + logger.debug(f"sql_alchemy_url={url}") + if self.config.authentication_type == "OAUTH_AUTHENTICATOR": + return create_engine( + url, + creator=self.config.get_oauth_connection, + **self.config.get_options(), + ) + else: + return create_engine( + url, + **self.config.get_options(), + ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py new file mode 100644 index 00000000000000..ed5e3341f70ffb --- /dev/null +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py @@ -0,0 +1,468 @@ +import logging +from dataclasses import dataclass, field +from datetime import datetime +from typing import Dict, List, Optional + +from snowflake.connector import DictCursor, SnowflakeConnection + +logger: logging.Logger = logging.getLogger(__name__) + + +@dataclass +class SnowflakePK: + name: str + column_names: List[str] + + +@dataclass +class SnowflakeFK: + name: str + column_names: List[str] + referred_database: str + referred_schema: str + referred_table: str + referred_column_names: List[str] + + +@dataclass +class SnowflakeColumn: + name: str + ordinal_position: int + is_nullable: bool + data_type: str + comment: str + + +@dataclass +class SnowflakeTable: + name: str + created: datetime + last_altered: datetime + size_in_bytes: int + rows_count: int + comment: str + clustering_key: str + pk: Optional[SnowflakePK] = None + columns: List[SnowflakeColumn] = field(default_factory=list) + foreign_keys: List[SnowflakeFK] = field(default_factory=list) + + +@dataclass +class SnowflakeView: + name: str + created: datetime + last_altered: datetime + comment: str + view_definition: str + columns: List[SnowflakeColumn] = field(default_factory=list) + + +@dataclass +class SnowflakeSchema: + name: str + created: datetime + last_altered: datetime + comment: str + tables: List[SnowflakeTable] = field(default_factory=list) + views: List[SnowflakeView] = field(default_factory=list) + + +@dataclass +class SnowflakeDatabase: + name: str + created: datetime + comment: str + schemas: List[SnowflakeSchema] = field(default_factory=list) + + +class SnowflakeQuery: + @staticmethod + def show_databases() -> str: + return "show databases" + + @staticmethod + def use_database(db_name: str) -> str: + return f'use database "{db_name}"' + + @staticmethod + def schemas_for_database(db_name: Optional[str]) -> str: + db_clause = f'"{db_name}".' if db_name is not None else "" + return f""" + SELECT schema_name as "schema_name", + created as "created", + last_altered as "last_altered", + comment as "comment" + from {db_clause}information_schema.schemata + WHERE schema_name != 'INFORMATION_SCHEMA' + order by schema_name""" + + @staticmethod + def tables_for_database(db_name: Optional[str]) -> str: + db_clause = f'"{db_name}".' 
if db_name is not None else "" + return f""" + SELECT table_catalog as "table_catalog", + table_schema as "table_schema", + table_name as "table_name", + table_type as "table_type", + created as "created", + last_altered as "last_altered" , + comment as "comment", + row_count as "row_count", + bytes as "bytes", + clustering_key as "clustering_key", + auto_clustering_on as "auto_clustering_on" + FROM {db_clause}information_schema.tables t + WHERE table_schema != 'INFORMATION_SCHEMA' + and table_type in ( 'BASE TABLE', 'EXTERNAL TABLE') + order by table_schema, table_name""" + + @staticmethod + def tables_for_schema(schema_name: str, db_name: Optional[str]) -> str: + db_clause = f'"{db_name}".' if db_name is not None else "" + return f""" + SELECT table_catalog as "table_catalog", + table_schema as "table_schema", + table_name as "table_name", + table_type as "table_type", + created as "created", + last_altered as "last_altered" , + comment as "comment", + row_count as "row_count", + bytes as "bytes", + clustering_key as "clustering_key", + auto_clustering_on as "auto_clustering_on" + FROM {db_clause}information_schema.tables t + where schema_name='{schema_name}' + and table_type in ('BASE TABLE', 'EXTERNAL TABLE') + order by table_schema, table_name""" + + @staticmethod + def views_for_database(db_name: Optional[str]) -> str: + db_clause = f'"{db_name}".' if db_name is not None else "" + return f""" + SELECT table_catalog as "table_catalog", + table_schema as "table_schema", + table_name as "table_name", + created as "created", + last_altered as "last_altered", + comment as "comment", + view_definition as "view_definition" + FROM {db_clause}information_schema.views t + WHERE table_schema != 'INFORMATION_SCHEMA' + order by table_schema, table_name""" + + @staticmethod + def views_for_schema(schema_name: str, db_name: Optional[str]) -> str: + db_clause = f'"{db_name}".' if db_name is not None else "" + return f""" + SELECT table_catalog as "table_catalog", + table_schema as "table_schema", + table_name as "table_name", + created as "created", + last_altered as "last_altered", + comment as "comment", + view_definition as "view_definition" + FROM {db_clause}information_schema.views t + where schema_name='{schema_name}' + order by table_schema, table_name""" + + @staticmethod + def columns_for_schema(schema_name: str, db_name: Optional[str]) -> str: + db_clause = f'"{db_name}".' if db_name is not None else "" + return f""" + select + table_catalog as "table_catalog", + table_schema as "table_schema", + table_name as "table_name", + column_name as "column_name", + ordinal_position as "ordinal_position", + is_nullable as "is_nullable", + data_type as "data_type", + comment as "comment", + character_maximum_length as "character_maximum_length", + numeric_precision as "numeric_precision", + numeric_scale as "numeric_scale", + column_default as "column_default", + is_identity as "is_identity" + from {db_clause}information_schema.columns + WHERE table_schema='{schema_name}' + ORDER BY ordinal_position""" + + @staticmethod + def columns_for_table( + table_name: str, schema_name: str, db_name: Optional[str] + ) -> str: + db_clause = f'"{db_name}".' 
if db_name is not None else "" + return f""" + select + table_catalog as "table_catalog", + table_schema as "table_schema", + table_name as "table_name", + column_name as "column_name", + ordinal_position as "ordinal_position", + is_nullable as "is_nullable", + data_type as "data_type", + comment as "comment", + character_maximum_length as "character_maximum_length", + numeric_precision as "numeric_precision", + numeric_scale as "numeric_scale", + column_default as "column_default", + is_identity as "is_identity" + from {db_clause}information_schema.columns + WHERE table_schema='{schema_name}' and table_name='{table_name}' + ORDER BY ordinal_position""" + + @staticmethod + def show_primary_keys_for_schema(schema_name: str, db_name: str) -> str: + return f""" + show primary keys in schema "{db_name}"."{schema_name}" """ + + @staticmethod + def show_foreign_keys_for_schema(schema_name: str, db_name: str) -> str: + return f""" + show imported keys in schema "{db_name}"."{schema_name}" """ + + +class SnowflakeDataDictionary: + def query(self, conn, query): + logger.debug("Query : {}".format(query)) + resp = conn.cursor(DictCursor).execute(query) + return resp + + def get_databases(self, conn: SnowflakeConnection) -> List[SnowflakeDatabase]: + + databases: List[SnowflakeDatabase] = [] + + cur = self.query( + conn, + SnowflakeQuery.show_databases(), + ) + + for database in cur: + snowflake_db = SnowflakeDatabase( + name=database["name"], + created=database["created_on"], + comment=database["comment"], + ) + databases.append(snowflake_db) + + return databases + + def get_schemas_for_database( + self, conn: SnowflakeConnection, db_name: str + ) -> List[SnowflakeSchema]: + + snowflake_schemas = [] + + cur = self.query( + conn, + SnowflakeQuery.schemas_for_database(db_name), + ) + for schema in cur: + snowflake_schema = SnowflakeSchema( + name=schema["schema_name"], + created=schema["created"], + last_altered=schema["last_altered"], + comment=schema["comment"], + ) + snowflake_schemas.append(snowflake_schema) + return snowflake_schemas + + def get_tables_for_database( + self, conn: SnowflakeConnection, db_name: str + ) -> Optional[Dict[str, List[SnowflakeTable]]]: + tables: Dict[str, List[SnowflakeTable]] = {} + try: + cur = self.query( + conn, + SnowflakeQuery.tables_for_database(db_name), + ) + except Exception as e: + logger.debug(e) + # Error - Information schema query returned too much data. Please repeat query with more selective predicates. 
+ return None + + for table in cur: + if table["table_schema"] not in tables: + tables[table["table_schema"]] = [] + tables[table["table_schema"]].append( + SnowflakeTable( + name=table["table_name"], + created=table["created"], + last_altered=table["last_altered"], + size_in_bytes=table["bytes"], + rows_count=table["row_count"], + comment=table["comment"], + clustering_key=table["clustering_key"], + ) + ) + return tables + + def get_tables_for_schema( + self, conn: SnowflakeConnection, schema_name: str, db_name: str + ) -> List[SnowflakeTable]: + tables: List[SnowflakeTable] = [] + + cur = self.query( + conn, + SnowflakeQuery.tables_for_schema(schema_name, db_name), + ) + + for table in cur: + tables.append( + SnowflakeTable( + name=table["table_name"], + created=table["created"], + last_altered=table["last_altered"], + size_in_bytes=table["bytes"], + rows_count=table["row_count"], + comment=table["comment"], + clustering_key=table["clustering_key"], + ) + ) + return tables + + def get_views_for_database( + self, conn: SnowflakeConnection, db_name: str + ) -> Optional[Dict[str, List[SnowflakeView]]]: + views: Dict[str, List[SnowflakeView]] = {} + try: + cur = self.query(conn, SnowflakeQuery.views_for_database(db_name)) + except Exception as e: + logger.debug(e) + # Error - Information schema query returned too much data. Please repeat query with more selective predicates. + return None + + for table in cur: + if table["table_schema"] not in views: + views[table["table_schema"]] = [] + views[table["table_schema"]].append( + SnowflakeView( + name=table["table_name"], + created=table["created"], + last_altered=table["last_altered"], + comment=table["comment"], + view_definition=table["view_definition"], + ) + ) + return views + + def get_views_for_schema( + self, conn: SnowflakeConnection, schema_name: str, db_name: str + ) -> List[SnowflakeView]: + views: List[SnowflakeView] = [] + + cur = self.query(conn, SnowflakeQuery.views_for_schema(schema_name, db_name)) + for table in cur: + views.append( + SnowflakeView( + name=table["table_name"], + created=table["created"], + last_altered=table["last_altered"], + comment=table["comment"], + view_definition=table["view_definition"], + ) + ) + return views + + def get_columns_for_schema( + self, conn: SnowflakeConnection, schema_name: str, db_name: str + ) -> Optional[Dict[str, List[SnowflakeColumn]]]: + columns: Dict[str, List[SnowflakeColumn]] = {} + try: + cur = self.query( + conn, SnowflakeQuery.columns_for_schema(schema_name, db_name) + ) + except Exception as e: + logger.debug(e) + # Error - Information schema query returned too much data. + # Please repeat query with more selective predicates. 
+ return None + + for column in cur: + if column["table_name"] not in columns: + columns[column["table_name"]] = [] + columns[column["table_name"]].append( + SnowflakeColumn( + name=column["column_name"], + ordinal_position=column["ordinal_position"], + is_nullable=column["is_nullable"] == "YES", + data_type=column["data_type"], + comment=column["comment"], + ) + ) + return columns + + def get_columns_for_table( + self, conn: SnowflakeConnection, table_name: str, schema_name: str, db_name: str + ) -> List[SnowflakeColumn]: + columns: List[SnowflakeColumn] = [] + + cur = self.query( + conn, + SnowflakeQuery.columns_for_table(table_name, schema_name, db_name), + ) + + for column in cur: + columns.append( + SnowflakeColumn( + name=column["column_name"], + ordinal_position=column["ordinal_position"], + is_nullable=column["is_nullable"] == "YES", + data_type=column["data_type"], + comment=column["comment"], + ) + ) + return columns + + def get_pk_constraints_for_schema( + self, conn: SnowflakeConnection, schema_name: str, db_name: str + ) -> Dict[str, SnowflakePK]: + constraints: Dict[str, SnowflakePK] = {} + cur = self.query( + conn, + SnowflakeQuery.show_primary_keys_for_schema(schema_name, db_name), + ) + + for row in cur: + if row["table_name"] not in constraints: + constraints[row["table_name"]] = SnowflakePK( + name=row["constraint_name"], column_names=[] + ) + constraints[row["table_name"]].column_names.append(row["column_name"]) + return constraints + + def get_fk_constraints_for_schema( + self, conn: SnowflakeConnection, schema_name: str, db_name: str + ) -> Dict[str, List[SnowflakeFK]]: + constraints: Dict[str, List[SnowflakeFK]] = {} + fk_constraints_map: Dict[str, SnowflakeFK] = {} + + cur = self.query( + conn, + SnowflakeQuery.show_foreign_keys_for_schema(schema_name, db_name), + ) + + for row in cur: + if row["fk_name"] not in constraints: + fk_constraints_map[row["fk_name"]] = SnowflakeFK( + name=row["fk_name"], + column_names=[], + referred_database=row["pk_database_name"], + referred_schema=row["pk_schema_name"], + referred_table=row["pk_table_name"], + referred_column_names=[], + ) + + if row["fk_table_name"] not in constraints: + constraints[row["fk_table_name"]] = [] + + fk_constraints_map[row["fk_name"]].column_names.append( + row["fk_column_name"] + ) + fk_constraints_map[row["fk_name"]].referred_column_names.append( + row["pk_column_name"] + ) + constraints[row["fk_table_name"]].append(fk_constraints_map[row["fk_name"]]) + + return constraints diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py new file mode 100644 index 00000000000000..20d5f8a43123eb --- /dev/null +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py @@ -0,0 +1,700 @@ +import logging +from datetime import datetime +from typing import Dict, Iterable, List, Optional, Tuple, Union, cast + +from avrogen.dict_wrapper import DictWrapper +from pydantic import Field, root_validator +from snowflake.connector import SnowflakeConnection + +from datahub.emitter.mce_builder import ( + make_data_platform_urn, + make_dataset_urn, + make_dataset_urn_with_platform_instance, + make_schema_field_urn, +) +from datahub.emitter.mcp import MetadataChangeProposalWrapper +from datahub.ingestion.api.common import PipelineContext, WorkUnit +from datahub.ingestion.api.decorators import ( + SupportStatus, + capability, + config_class, + platform_name, + support_status, +) +from 
datahub.ingestion.api.source import ( + Source, + SourceCapability, + SourceReport, + TestableSource, + TestConnectionReport, +) +from datahub.ingestion.api.workunit import MetadataWorkUnit +from datahub.ingestion.source.snowflake.snowflake_lineage import ( + SnowflakeLineageExtractor, +) +from datahub.ingestion.source.snowflake.snowflake_schema import ( + SnowflakeColumn, + SnowflakeDatabase, + SnowflakeDataDictionary, + SnowflakeFK, + SnowflakePK, + SnowflakeQuery, + SnowflakeTable, + SnowflakeView, +) +from datahub.ingestion.source.sql.snowflake import SnowflakeSource +from datahub.ingestion.source.sql.sql_common import ( + SQLAlchemySource, + SQLAlchemyStatefulIngestionConfig, +) +from datahub.ingestion.source_config.sql.snowflake import ( + SnowflakeConfig, + SnowflakeProvisionRoleConfig, +) +from datahub.ingestion.source_report.sql.snowflake import SnowflakeReport +from datahub.metadata.com.linkedin.pegasus2avro.common import Status, SubTypes +from datahub.metadata.com.linkedin.pegasus2avro.dataset import ( + DatasetProfile, + DatasetProperties, + UpstreamLineage, + ViewProperties, +) +from datahub.metadata.com.linkedin.pegasus2avro.events.metadata import ChangeType +from datahub.metadata.com.linkedin.pegasus2avro.schema import ( + ArrayType, + BooleanType, + BytesType, + DateType, + ForeignKeyConstraint, + MySqlDDL, + NullType, + NumberType, + RecordType, + SchemaField, + SchemaFieldDataType, + SchemaMetadata, + StringType, + TimeType, +) + +logger: logging.Logger = logging.getLogger(__name__) + +# https://docs.snowflake.com/en/sql-reference/intro-summary-data-types.html +SNOWFLAKE_FIELD_TYPE_MAPPINGS = { + "DATE": DateType, + "BIGINT": NumberType, + "BINARY": BytesType, + # 'BIT': BIT, + "BOOLEAN": BooleanType, + "CHAR": NullType, + "CHARACTER": NullType, + "DATETIME": TimeType, + "DEC": NumberType, + "DECIMAL": NumberType, + "DOUBLE": NumberType, + "FIXED": NumberType, + "FLOAT": NumberType, + "INT": NumberType, + "INTEGER": NumberType, + "NUMBER": NumberType, + # 'OBJECT': ? 
+ "REAL": NumberType, + "BYTEINT": NumberType, + "SMALLINT": NumberType, + "STRING": StringType, + "TEXT": StringType, + "TIME": TimeType, + "TIMESTAMP": TimeType, + "TIMESTAMP_TZ": TimeType, + "TIMESTAMP_LTZ": TimeType, + "TIMESTAMP_NTZ": TimeType, + "TINYINT": NumberType, + "VARBINARY": BytesType, + "VARCHAR": StringType, + "VARIANT": RecordType, + "OBJECT": NullType, + "ARRAY": ArrayType, + "GEOGRAPHY": NullType, +} + + +@platform_name("Snowflake") +@config_class(SnowflakeConfig) +@support_status(SupportStatus.INCUBATING) +@capability(SourceCapability.PLATFORM_INSTANCE, "Enabled by default") +@capability(SourceCapability.DOMAINS, "Supported via the `domain` config field") +@capability(SourceCapability.CONTAINERS, "Enabled by default") +@capability(SourceCapability.SCHEMA_METADATA, "Enabled by default") +@capability( + SourceCapability.DATA_PROFILING, + "Optionally enabled via configuration, only table level profiling is supported", +) +@capability(SourceCapability.DESCRIPTIONS, "Enabled by default") +@capability(SourceCapability.LINEAGE_COARSE, "Optionally enabled via configuration") +@capability(SourceCapability.DELETION_DETECTION, "Coming soon", supported=False) +class SnowflakeV2Config(SnowflakeConfig): + _convert_urns_to_lowercase: bool = Field( + default=True, + exclude=True, + description="Not supported", + ) + + check_role_grants: bool = Field( + default=False, + exclude=True, + description="Not supported", + ) + + provision_role: Optional[SnowflakeProvisionRoleConfig] = Field( + default=None, exclude=True, description="Not supported" + ) + + stateful_ingestion: Optional[SQLAlchemyStatefulIngestionConfig] = Field( + default=None, exclude=True, description="Not supported" + ) + + @root_validator(pre=False) + def validate_unsupported_configs(cls, values: Dict) -> Dict: + value = values.get("stateful_ingestion") + if value is not None and value.enabled: + raise ValueError( + "Stateful ingestion is currently not supported. Set `stateful_ingestion.enabled` to False" + ) + + value = values.get("provision_role") + if value is not None and value.enabled: + raise ValueError( + "Provision role is currently not supported. Set `provision_role.enabled` to False." + ) + + value = values.get("profiling") + if value is not None and value.enabled and not value.profile_table_level_only: + raise ValueError( + "Only table level profiling is supported. Set `profiling.profile_table_level_only` to True.", + ) + + value = values.get("check_role_grants") + if value is not None and value: + raise ValueError( + "Check role grants is not supported. 
Set `check_role_grants` to False.", + ) + return values + + +@platform_name("Snowflake") +@config_class(SnowflakeV2Config) +class SnowflakeV2Source(TestableSource): + def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config): + super().__init__(ctx) + self.config: SnowflakeV2Config = config + self.report: SnowflakeReport = SnowflakeReport() + self.platform: str = "snowflake" + self.sql_common: SQLAlchemySource = SQLAlchemySource(config, ctx, self.platform) + + # For database, schema, tables, views, etc + self.data_dictionary = SnowflakeDataDictionary() + + # For lineage + self.lineage_extractor = SnowflakeLineageExtractor(config, self.report) + + # Currently caching using instance variables + # TODO - rewrite cache for readability or use out of the box solution + self.db_tables: Dict[str, Optional[Dict[str, List[SnowflakeTable]]]] = {} + self.db_views: Dict[str, Optional[Dict[str, List[SnowflakeView]]]] = {} + + # For column related queries and constraints, we currently query at schema level + # In future, we may consider using queries and caching at database level first + self.schema_columns: Dict[ + Tuple[str, str], Optional[Dict[str, List[SnowflakeColumn]]] + ] = {} + self.schema_pk_constraints: Dict[Tuple[str, str], Dict[str, SnowflakePK]] = {} + self.schema_fk_constraints: Dict[ + Tuple[str, str], Dict[str, List[SnowflakeFK]] + ] = {} + + @classmethod + def create(cls, config_dict: dict, ctx: PipelineContext) -> "Source": + config = SnowflakeV2Config.parse_obj(config_dict) + return cls(ctx, config) + + @staticmethod + def test_connection(config_dict: dict) -> TestConnectionReport: + return SnowflakeSource.test_connection(config_dict) + + def snowflake_identifier(self, identifier: str) -> str: + # to be in in sync with older connector, convert name to lowercase + if self.config._convert_urns_to_lowercase: + return identifier.lower() + return identifier + + def get_workunits(self) -> Iterable[WorkUnit]: + + conn: SnowflakeConnection = self.config.get_connection() + self.add_config_to_report() + + self.inspect_session_metadata(conn) + databases: List[SnowflakeDatabase] = self.data_dictionary.get_databases(conn) + for snowflake_db in databases: + db_name = snowflake_db.name + + if not self.config.database_pattern.allowed(db_name): + self.report.report_dropped(db_name) + continue + + database_workunits = self.sql_common.gen_database_containers( + self.snowflake_identifier(db_name) + ) + + for wu in database_workunits: + self.report.report_workunit(wu) + yield wu + + # Use database and extract metadata from its information_schema + # If this query fails, it means, user does not have usage access on database + try: + self.data_dictionary.query(conn, SnowflakeQuery.use_database(db_name)) + except Exception as e: + self.report.report_warning( + db_name, + f"unable to get metadata information for database {db_name} due to an error -> {e}", + ) + self.report.report_dropped(db_name) + continue + + snowflake_db.schemas = self.data_dictionary.get_schemas_for_database( + conn, db_name + ) + + for snowflake_schema in snowflake_db.schemas: + schema_name = snowflake_schema.name + if not self.config.schema_pattern.allowed(schema_name): + self.report.report_dropped(f"{schema_name}.*") + continue + + schema_workunits = self.sql_common.gen_schema_containers( + self.snowflake_identifier(schema_name), + self.snowflake_identifier(db_name), + ) + + for wu in schema_workunits: + self.report.report_workunit(wu) + yield wu + + if self.config.include_tables: + snowflake_schema.tables = 
self.get_tables_for_schema( + conn, schema_name, db_name + ) + + for table in snowflake_schema.tables: + table_identifier = self.get_identifier( + table.name, schema_name, db_name + ) + + self.report.report_entity_scanned(table_identifier) + + if not self.config.table_pattern.allowed(table_identifier): + self.report.report_dropped(table_identifier) + continue + + table.columns = self.get_columns_for_table( + conn, table.name, schema_name, db_name + ) + table.pk = self.get_pk_constraints_for_table( + conn, table.name, schema_name, db_name + ) + table.foreign_keys = self.get_fk_constraints_for_table( + conn, table.name, schema_name, db_name + ) + dataset_name = self.get_identifier( + table.name, schema_name, db_name + ) + + # TODO: rewrite lineage extractor to honour _convert_urns_to_lowercase=False config + # Currently it generates backward compatible (lowercase) urns only + lineage_info = ( + self.lineage_extractor._get_upstream_lineage_info( + dataset_name + ) + ) + + table_workunits = self.gen_dataset_workunits( + table, schema_name, db_name, lineage_info + ) + for wu in table_workunits: + self.report.report_workunit(wu) + yield wu + + if self.config.include_views: + snowflake_schema.views = self.get_views_for_schema( + conn, schema_name, db_name + ) + + for view in snowflake_schema.views: + table_identifier = self.get_identifier( + view.name, schema_name, db_name + ) + + self.report.report_entity_scanned(table_identifier, "view") + + if not self.config.view_pattern.allowed(table_identifier): + self.report.report_dropped(table_identifier) + continue + + view.columns = self.get_columns_for_table( + conn, view.name, schema_name, db_name + ) + dataset_name = self.get_identifier( + view.name, schema_name, db_name + ) + lineage_info = ( + self.lineage_extractor._get_upstream_lineage_info( + dataset_name + ) + ) + view_workunits = self.gen_dataset_workunits( + view, schema_name, db_name, lineage_info + ) + for wu in view_workunits: + self.report.report_workunit(wu) + yield wu + + def gen_dataset_workunits( + self, + table: Union[SnowflakeTable, SnowflakeView], + schema_name: str, + db_name: str, + lineage_info: Optional[Tuple[UpstreamLineage, Dict[str, str]]], + ) -> Iterable[MetadataWorkUnit]: + dataset_name = self.get_identifier(table.name, schema_name, db_name) + dataset_urn = make_dataset_urn_with_platform_instance( + self.platform, + dataset_name, + self.config.platform_instance, + self.config.env, + ) + if lineage_info is not None: + upstream_lineage, upstream_column_props = lineage_info + else: + upstream_column_props = {} + upstream_lineage = None + + status = Status(removed=False) + yield self.wrap_aspect_as_workunit("dataset", dataset_urn, "status", status) + + foreign_keys: Optional[List[ForeignKeyConstraint]] = None + if isinstance(table, SnowflakeTable) and len(table.foreign_keys) > 0: + foreign_keys = [] + for fk in table.foreign_keys: + foreign_dataset = make_dataset_urn( + self.platform, + self.get_identifier( + fk.referred_table, fk.referred_schema, fk.referred_database + ), + self.config.env, + ) + foreign_keys.append( + ForeignKeyConstraint( + name=fk.name, + foreignDataset=foreign_dataset, + foreignFields=[ + make_schema_field_urn( + foreign_dataset, self.snowflake_identifier(col) + ) + for col in fk.referred_column_names + ], + sourceFields=[ + make_schema_field_urn( + dataset_urn, self.snowflake_identifier(col) + ) + for col in fk.column_names + ], + ) + ) + + schema_metadata = SchemaMetadata( + schemaName=dataset_name, + platform=make_data_platform_urn(self.platform), + 
version=0, + hash="", + platformSchema=MySqlDDL(tableSchema=""), + fields=[ + SchemaField( + fieldPath=self.snowflake_identifier(col.name), + type=SchemaFieldDataType( + SNOWFLAKE_FIELD_TYPE_MAPPINGS.get(col.data_type, NullType)() + ), + # NOTE: nativeDataType will not be in sync with older connector + nativeDataType=col.data_type, + description=col.comment, + nullable=col.is_nullable, + isPartOfKey=col.name in table.pk.column_names + if isinstance(table, SnowflakeTable) and table.pk is not None + else None, + ) + for col in table.columns + ], + foreignKeys=foreign_keys, + ) + yield self.wrap_aspect_as_workunit( + "dataset", dataset_urn, "schemaMetadata", schema_metadata + ) + + dataset_properties = DatasetProperties( + name=table.name, + description=table.comment, + qualifiedName=dataset_name, + customProperties={**upstream_column_props}, + ) + yield self.wrap_aspect_as_workunit( + "dataset", dataset_urn, "datasetProperties", dataset_properties + ) + + yield from self.sql_common.add_table_to_schema_container( + dataset_urn, + self.snowflake_identifier(db_name), + self.snowflake_identifier(schema_name), + ) + dpi_aspect = self.sql_common.get_dataplatform_instance_aspect( + dataset_urn=dataset_urn + ) + if dpi_aspect: + yield dpi_aspect + + subtypes_aspect = MetadataWorkUnit( + id=f"{dataset_name}-subtypes", + mcp=MetadataChangeProposalWrapper( + entityType="dataset", + changeType=ChangeType.UPSERT, + entityUrn=dataset_urn, + aspectName="subTypes", + aspect=SubTypes( + typeNames=["view"] + if isinstance(table, SnowflakeView) + else ["table"] + ), + ), + ) + yield subtypes_aspect + + yield from self.sql_common._get_domain_wu( + dataset_name=dataset_name, + entity_urn=dataset_urn, + entity_type="dataset", + sql_config=self.config, + ) + + if upstream_lineage is not None: + # Emit the lineage work unit + lineage_mcpw = MetadataChangeProposalWrapper( + entityType="dataset", + changeType=ChangeType.UPSERT, + entityUrn=dataset_urn, + aspectName="upstreamLineage", + aspect=upstream_lineage, + ) + lineage_wu = MetadataWorkUnit( + id=f"{self.platform}-{lineage_mcpw.entityUrn}-{lineage_mcpw.aspectName}", + mcp=lineage_mcpw, + ) + yield lineage_wu + + if isinstance(table, SnowflakeTable) and self.config.profiling.enabled: + if self.config.profiling.allow_deny_patterns.allowed(dataset_name): + # Emit the profile work unit + dataset_profile = DatasetProfile( + timestampMillis=round(datetime.now().timestamp() * 1000), + columnCount=len(table.columns), + rowCount=table.rows_count, + ) + profile_mcpw = MetadataChangeProposalWrapper( + entityType="dataset", + changeType=ChangeType.UPSERT, + entityUrn=dataset_urn, + aspectName="datasetProfile", + aspect=dataset_profile, + ) + profile_wu = MetadataWorkUnit( + id=f"{self.platform}-{profile_mcpw.entityUrn}-{profile_mcpw.aspectName}", + mcp=profile_mcpw, + ) + self.report.report_entity_profiled(dataset_name) + yield profile_wu + else: + self.report.report_dropped(f"Profile for {dataset_name}") + + if isinstance(table, SnowflakeView): + view = cast(SnowflakeView, table) + view_definition_string = view.view_definition + view_properties_aspect = ViewProperties( + materialized=False, viewLanguage="SQL", viewLogic=view_definition_string + ) + view_properties_wu = MetadataWorkUnit( + id=f"{view.name}-viewProperties", + mcp=MetadataChangeProposalWrapper( + entityType="dataset", + changeType=ChangeType.UPSERT, + entityUrn=dataset_urn, + aspectName="viewProperties", + aspect=view_properties_aspect, + ), + ) + yield view_properties_wu + + def get_identifier(self, 
table_name: str, schema_name: str, db_name: str) -> str: + return self.snowflake_identifier(f"{db_name}.{schema_name}.{table_name}") + + def get_report(self) -> SourceReport: + return self.report + + def get_tables_for_schema( + self, conn: SnowflakeConnection, schema_name: str, db_name: str + ) -> List[SnowflakeTable]: + + if db_name not in self.db_tables.keys(): + tables = self.data_dictionary.get_tables_for_database(conn, db_name) + self.db_tables[db_name] = tables + else: + tables = self.db_tables[db_name] + + # get all tables for database failed, + # falling back to get tables for schema + if tables is None: + return self.data_dictionary.get_tables_for_schema( + conn, schema_name, db_name + ) + + # Some schema may not have any table + return tables.get(schema_name, []) + + def get_views_for_schema( + self, conn: SnowflakeConnection, schema_name: str, db_name: str + ) -> List[SnowflakeView]: + + if db_name not in self.db_views.keys(): + views = self.data_dictionary.get_views_for_database(conn, db_name) + self.db_views[db_name] = views + else: + views = self.db_views[db_name] + + # get all views for database failed, + # falling back to get views for schema + if views is None: + return self.data_dictionary.get_views_for_schema(conn, schema_name, db_name) + + # Some schema may not have any table + return views.get(schema_name, []) + + def get_columns_for_table( + self, conn: SnowflakeConnection, table_name: str, schema_name: str, db_name: str + ) -> List[SnowflakeColumn]: + + if (db_name, schema_name) not in self.schema_columns.keys(): + columns = self.data_dictionary.get_columns_for_schema( + conn, schema_name, db_name + ) + self.schema_columns[(db_name, schema_name)] = columns + else: + columns = self.schema_columns[(db_name, schema_name)] + + # get all columns for schema failed, + # falling back to get columns for table + if columns is None: + return self.data_dictionary.get_columns_for_table( + conn, table_name, schema_name, db_name + ) + + # Access to table but none of its columns - is this possible ? + return columns.get(table_name, []) + + def get_pk_constraints_for_table( + self, conn: SnowflakeConnection, table_name: str, schema_name: str, db_name: str + ) -> Optional[SnowflakePK]: + + if (db_name, schema_name) not in self.schema_pk_constraints.keys(): + constraints = self.data_dictionary.get_pk_constraints_for_schema( + conn, schema_name, db_name + ) + self.schema_pk_constraints[(db_name, schema_name)] = constraints + else: + constraints = self.schema_pk_constraints[(db_name, schema_name)] + + # Access to table but none of its constraints - is this possible ? + return constraints.get(table_name) + + def get_fk_constraints_for_table( + self, conn: SnowflakeConnection, table_name: str, schema_name: str, db_name: str + ) -> List[SnowflakeFK]: + + if (db_name, schema_name) not in self.schema_fk_constraints.keys(): + constraints = self.data_dictionary.get_fk_constraints_for_schema( + conn, schema_name, db_name + ) + self.schema_fk_constraints[(db_name, schema_name)] = constraints + else: + constraints = self.schema_fk_constraints[(db_name, schema_name)] + + # Access to table but none of its constraints - is this possible ? 
+ return constraints.get(table_name, []) + + def add_config_to_report(self): + self.report.cleaned_account_id = self.config.get_account() + self.report.ignore_start_time_lineage = self.config.ignore_start_time_lineage + self.report.upstream_lineage_in_report = self.config.upstream_lineage_in_report + if not self.report.ignore_start_time_lineage: + self.report.lineage_start_time = self.config.start_time + self.report.lineage_end_time = self.config.end_time + self.report.check_role_grants = self.config.check_role_grants + + def warn(self, log: logging.Logger, key: str, reason: str) -> None: + self.report.report_warning(key, reason) + log.warning(f"{key} => {reason}") + + def inspect_session_metadata(self, conn: SnowflakeConnection) -> None: + try: + logger.info("Checking current version") + for db_row in self.data_dictionary.query(conn, "select CURRENT_VERSION()"): + self.report.saas_version = db_row["CURRENT_VERSION()"] + except Exception as e: + self.report.report_failure("version", f"Error: {e}") + try: + logger.info("Checking current role") + for db_row in self.data_dictionary.query(conn, "select CURRENT_ROLE()"): + self.report.role = db_row["CURRENT_ROLE()"] + except Exception as e: + self.report.report_failure("version", f"Error: {e}") + try: + logger.info("Checking current warehouse") + for db_row in self.data_dictionary.query( + conn, "select CURRENT_WAREHOUSE()" + ): + self.report.default_warehouse = db_row["CURRENT_WAREHOUSE()"] + except Exception as e: + self.report.report_failure("current_warehouse", f"Error: {e}") + try: + logger.info("Checking current database") + for db_row in self.data_dictionary.query(conn, "select CURRENT_DATABASE()"): + self.report.default_db = db_row["CURRENT_DATABASE()"] + except Exception as e: + self.report.report_failure("current_database", f"Error: {e}") + try: + logger.info("Checking current schema") + for db_row in self.data_dictionary.query(conn, "select CURRENT_SCHEMA()"): + self.report.default_schema = db_row["CURRENT_SCHEMA()"] + except Exception as e: + self.report.report_failure("current_schema", f"Error: {e}") + + def wrap_aspect_as_workunit( + self, entityName: str, entityUrn: str, aspectName: str, aspect: DictWrapper + ) -> MetadataWorkUnit: + wu = MetadataWorkUnit( + id=f"{aspectName}-for-{entityUrn}", + mcp=MetadataChangeProposalWrapper( + entityType=entityName, + entityUrn=entityUrn, + aspectName=aspectName, + aspect=aspect, + changeType=ChangeType.UPSERT, + ), + ) + self.report.report_workunit(wu) + return wu diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/snowflake.py b/metadata-ingestion/src/datahub/ingestion/source/sql/snowflake.py index c3c0518ead788c..e6b5e6578d3809 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/snowflake.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/snowflake.py @@ -1,9 +1,8 @@ import json import logging -from collections import defaultdict from dataclasses import dataclass from datetime import datetime -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Union import pydantic @@ -33,7 +32,9 @@ TestConnectionReport, ) from datahub.ingestion.api.workunit import MetadataWorkUnit -from datahub.ingestion.source.aws.s3_util import make_s3_urn +from datahub.ingestion.source.snowflake.snowflake_lineage import ( + SnowflakeLineageExtractor, +) from datahub.ingestion.source.sql.sql_common import ( RecordTypeClass, SQLAlchemySource, @@ -43,11 +44,6 @@ ) from 
datahub.ingestion.source_config.sql.snowflake import SnowflakeConfig from datahub.ingestion.source_report.sql.snowflake import SnowflakeReport -from datahub.metadata.com.linkedin.pegasus2avro.dataset import ( - DatasetLineageTypeClass, - UpstreamClass, - UpstreamLineage, -) from datahub.metadata.com.linkedin.pegasus2avro.metadata.snapshot import DatasetSnapshot from datahub.metadata.com.linkedin.pegasus2avro.mxe import MetadataChangeEvent from datahub.metadata.schema_classes import ChangeTypeClass, DatasetPropertiesClass @@ -76,9 +72,8 @@ class SnowflakeSource(SQLAlchemySource, TestableSource): def __init__(self, config: SnowflakeConfig, ctx: PipelineContext): super().__init__(config, ctx, "snowflake") - self._lineage_map: Optional[Dict[str, List[Tuple[str, str, str]]]] = None - self._external_lineage_map: Optional[Dict[str, Set[str]]] = None self.report: SnowflakeReport = SnowflakeReport() + self.lineage_extractor = SnowflakeLineageExtractor(config, self.report) self.config: SnowflakeConfig = config self.provision_role_in_progress: bool = False self.profile_candidates: Dict[str, List[str]] = {} @@ -104,7 +99,7 @@ def query(query): _report: Dict[Union[SourceCapability, str], CapabilityReport] = dict() privileges: List[SnowflakePrivilege] = [] - capabilities: List[SourceCapability] = [c.capability for c in SnowflakeSource.get_capabilities() if c.capability not in (SourceCapability.PLATFORM_INSTANCE, SourceCapability.DOMAINS, SourceCapability.DELETION_DETECTION)] # type: ignore + capabilities: List[SourceCapability] = [c.capability for c in SnowflakeSource.get_capabilities() if c.supported and c.capability not in (SourceCapability.PLATFORM_INSTANCE, SourceCapability.DOMAINS, SourceCapability.DELETION_DETECTION)] # type: ignore cur = query("select current_role()") current_role = [row[0] for row in cur][0] @@ -364,371 +359,6 @@ def get_identifier( ) return f"{self.current_database.lower()}.{regular}" - def _populate_view_upstream_lineage(self, engine: sqlalchemy.engine.Engine) -> None: - # NOTE: This query captures only the upstream lineage of a view (with no column lineage). - # For more details see: https://docs.snowflake.com/en/user-guide/object-dependencies.html#object-dependencies - # and also https://docs.snowflake.com/en/sql-reference/account-usage/access_history.html#usage-notes for current limitations on capturing the lineage for views. - view_upstream_lineage_query: str = """ -SELECT - concat( - referenced_database, '.', referenced_schema, - '.', referenced_object_name - ) AS view_upstream, - concat( - referencing_database, '.', referencing_schema, - '.', referencing_object_name - ) AS downstream_view -FROM - snowflake.account_usage.object_dependencies -WHERE - referencing_object_domain in ('VIEW', 'MATERIALIZED VIEW') - """ - - assert self._lineage_map is not None - num_edges: int = 0 - - try: - for db_row in engine.execute(view_upstream_lineage_query): - # Process UpstreamTable/View/ExternalTable/Materialized View->View edge. - view_upstream: str = db_row["view_upstream"].lower() - view_name: str = db_row["downstream_view"].lower() - if not self._is_dataset_allowed(dataset_name=view_name, is_view=True): - continue - # key is the downstream view name - self._lineage_map[view_name].append( - # (, , ) - (view_upstream, "[]", "[]") - ) - num_edges += 1 - logger.debug( - f"Upstream->View: Lineage[View(Down)={view_name}]:Upstream={view_upstream}" - ) - except Exception as e: - self.warn( - logger, - "view_upstream_lineage", - "Extracting the upstream view lineage from Snowflake failed." 
- + f"Please check your permissions. Continuing...\nError was {e}.", - ) - logger.info(f"A total of {num_edges} View upstream edges found.") - self.report.num_table_to_view_edges_scanned = num_edges - - def _populate_view_downstream_lineage( - self, engine: sqlalchemy.engine.Engine - ) -> None: - # This query captures the downstream table lineage for views. - # See https://docs.snowflake.com/en/sql-reference/account-usage/access_history.html#usage-notes for current limitations on capturing the lineage for views. - # Eg: For viewA->viewB->ViewC->TableD, snowflake does not yet log intermediate view logs, resulting in only the viewA->TableD edge. - view_lineage_query: str = """ -WITH view_lineage_history AS ( - SELECT - vu.value : "objectName" AS view_name, - vu.value : "objectDomain" AS view_domain, - vu.value : "columns" AS view_columns, - w.value : "objectName" AS downstream_table_name, - w.value : "objectDomain" AS downstream_table_domain, - w.value : "columns" AS downstream_table_columns, - t.query_start_time AS query_start_time - FROM - ( - SELECT - * - FROM - snowflake.account_usage.access_history - ) t, - lateral flatten(input => t.DIRECT_OBJECTS_ACCESSED) vu, - lateral flatten(input => t.OBJECTS_MODIFIED) w - WHERE - vu.value : "objectId" IS NOT NULL - AND w.value : "objectId" IS NOT NULL - AND w.value : "objectName" NOT LIKE '%.GE_TMP_%' - AND w.value : "objectName" NOT LIKE '%.GE_TEMP_%' - AND t.query_start_time >= to_timestamp_ltz({start_time_millis}, 3) - AND t.query_start_time < to_timestamp_ltz({end_time_millis}, 3) -) -SELECT - view_name, - view_columns, - downstream_table_name, - downstream_table_columns -FROM - view_lineage_history -WHERE - view_domain in ('View', 'Materialized view') - QUALIFY ROW_NUMBER() OVER ( - PARTITION BY view_name, - downstream_table_name - ORDER BY - query_start_time DESC - ) = 1 - """.format( - start_time_millis=int(self.config.start_time.timestamp() * 1000) - if not self.config.ignore_start_time_lineage - else 0, - end_time_millis=int(self.config.end_time.timestamp() * 1000), - ) - - assert self._lineage_map is not None - self.report.num_view_to_table_edges_scanned = 0 - - try: - db_rows = engine.execute(view_lineage_query) - except Exception as e: - self.warn( - logger, - "view_downstream_lineage", - f"Extracting the view lineage from Snowflake failed." - f"Please check your permissions. Continuing...\nError was {e}.", - ) - else: - for db_row in db_rows: - view_name: str = db_row["view_name"].lower().replace('"', "") - if not self._is_dataset_allowed(dataset_name=view_name, is_view=True): - continue - downstream_table: str = ( - db_row["downstream_table_name"].lower().replace('"', "") - ) - # Capture view->downstream table lineage. - self._lineage_map[downstream_table].append( - # (, , ) - ( - view_name, - db_row["view_columns"], - db_row["downstream_table_columns"], - ) - ) - self.report.num_view_to_table_edges_scanned += 1 - - logger.debug( - f"View->Table: Lineage[Table(Down)={downstream_table}]:View(Up)={self._lineage_map[downstream_table]}" - ) - - logger.info( - f"Found {self.report.num_view_to_table_edges_scanned} View->Table edges." 
- ) - - def _populate_view_lineage(self) -> None: - if not self.config.include_view_lineage: - return - engine = self.get_metadata_engine(database=None) - self._populate_view_upstream_lineage(engine) - self._populate_view_downstream_lineage(engine) - - def _populate_external_lineage(self) -> None: - engine = self.get_metadata_engine(database=None) - # Handles the case where a table is populated from an external location via copy. - # Eg: copy into category_english from 's3://acryl-snow-demo-olist/olist_raw_data/category_english'credentials=(aws_key_id='...' aws_secret_key='...') pattern='.*.csv'; - query: str = """ - WITH external_table_lineage_history AS ( - SELECT - r.value:"locations" as upstream_locations, - w.value:"objectName" AS downstream_table_name, - w.value:"objectDomain" AS downstream_table_domain, - w.value:"columns" AS downstream_table_columns, - t.query_start_time AS query_start_time - FROM - (SELECT * from snowflake.account_usage.access_history) t, - lateral flatten(input => t.BASE_OBJECTS_ACCESSED) r, - lateral flatten(input => t.OBJECTS_MODIFIED) w - WHERE r.value:"locations" IS NOT NULL - AND w.value:"objectId" IS NOT NULL - AND t.query_start_time >= to_timestamp_ltz({start_time_millis}, 3) - AND t.query_start_time < to_timestamp_ltz({end_time_millis}, 3)) - SELECT upstream_locations, downstream_table_name, downstream_table_columns - FROM external_table_lineage_history - WHERE downstream_table_domain = 'Table' - QUALIFY ROW_NUMBER() OVER (PARTITION BY downstream_table_name ORDER BY query_start_time DESC) = 1""".format( - start_time_millis=int(self.config.start_time.timestamp() * 1000) - if not self.config.ignore_start_time_lineage - else 0, - end_time_millis=int(self.config.end_time.timestamp() * 1000), - ) - - num_edges: int = 0 - self._external_lineage_map = defaultdict(set) - try: - for db_row in engine.execute(query): - # key is the down-stream table name - key: str = db_row[1].lower().replace('"', "") - if not self._is_dataset_allowed(key): - continue - self._external_lineage_map[key] |= {*json.loads(db_row[0])} - logger.debug( - f"ExternalLineage[Table(Down)={key}]:External(Up)={self._external_lineage_map[key]} via access_history" - ) - except Exception as e: - logger.warning( - f"Populating table external lineage from Snowflake failed." - f"Please check your premissions. Continuing...\nError was {e}." - ) - # Handles the case for explicitly created external tables. - # NOTE: Snowflake does not log this information to the access_history table. - external_tables_query: str = "show external tables in account" - try: - for db_row in engine.execute(external_tables_query): - key = ( - f"{db_row.database_name}.{db_row.schema_name}.{db_row.name}".lower() - ) - if not self._is_dataset_allowed(dataset_name=key): - continue - self._external_lineage_map[key].add(db_row.location) - logger.debug( - f"ExternalLineage[Table(Down)={key}]:External(Up)={self._external_lineage_map[key]} via show external tables" - ) - num_edges += 1 - except Exception as e: - self.warn( - logger, - "external_lineage", - f"Populating external table lineage from Snowflake failed." - f"Please check your premissions. 
Continuing...\nError was {e}.", - ) - logger.info(f"Found {num_edges} external lineage edges.") - self.report.num_external_table_edges_scanned = num_edges - - def _populate_lineage(self) -> None: - engine = self.get_metadata_engine(database=None) - query: str = """ -WITH table_lineage_history AS ( - SELECT - r.value:"objectName" AS upstream_table_name, - r.value:"objectDomain" AS upstream_table_domain, - r.value:"columns" AS upstream_table_columns, - w.value:"objectName" AS downstream_table_name, - w.value:"objectDomain" AS downstream_table_domain, - w.value:"columns" AS downstream_table_columns, - t.query_start_time AS query_start_time - FROM - (SELECT * from snowflake.account_usage.access_history) t, - lateral flatten(input => t.DIRECT_OBJECTS_ACCESSED) r, - lateral flatten(input => t.OBJECTS_MODIFIED) w - WHERE r.value:"objectId" IS NOT NULL - AND w.value:"objectId" IS NOT NULL - AND w.value:"objectName" NOT LIKE '%.GE_TMP_%' - AND w.value:"objectName" NOT LIKE '%.GE_TEMP_%' - AND t.query_start_time >= to_timestamp_ltz({start_time_millis}, 3) - AND t.query_start_time < to_timestamp_ltz({end_time_millis}, 3)) -SELECT upstream_table_name, downstream_table_name, upstream_table_columns, downstream_table_columns -FROM table_lineage_history -WHERE upstream_table_domain in ('Table', 'External table') and downstream_table_domain = 'Table' -QUALIFY ROW_NUMBER() OVER (PARTITION BY downstream_table_name, upstream_table_name ORDER BY query_start_time DESC) = 1 """.format( - start_time_millis=int(self.config.start_time.timestamp() * 1000) - if not self.config.ignore_start_time_lineage - else 0, - end_time_millis=int(self.config.end_time.timestamp() * 1000), - ) - num_edges: int = 0 - self._lineage_map = defaultdict(list) - try: - for db_row in engine.execute(query): - # key is the down-stream table name - key: str = db_row[1].lower().replace('"', "") - upstream_table_name = db_row[0].lower().replace('"', "") - if not ( - self._is_dataset_allowed(key) - or self._is_dataset_allowed(upstream_table_name) - ): - continue - self._lineage_map[key].append( - # (, , ) - (upstream_table_name, db_row[2], db_row[3]) - ) - num_edges += 1 - logger.debug( - f"Lineage[Table(Down)={key}]:Table(Up)={self._lineage_map[key]}" - ) - except Exception as e: - self.warn( - logger, - "lineage", - f"Extracting lineage from Snowflake failed." - f"Please check your premissions. Continuing...\nError was {e}.", - ) - logger.info( - f"A total of {num_edges} Table->Table edges found" - f" for {len(self._lineage_map)} downstream tables.", - ) - self.report.num_table_to_table_edges_scanned = num_edges - - def _get_upstream_lineage_info( - self, dataset_urn: str - ) -> Optional[Tuple[UpstreamLineage, Dict[str, str]]]: - dataset_key = builder.dataset_urn_to_key(dataset_urn) - if dataset_key is None: - logger.warning(f"Invalid dataset urn {dataset_urn}. 
Could not get key!") - return None - - if self._lineage_map is None: - self._populate_lineage() - self._populate_view_lineage() - if self._external_lineage_map is None: - self._populate_external_lineage() - - assert self._lineage_map is not None - assert self._external_lineage_map is not None - dataset_name = dataset_key.name - lineage = self._lineage_map[dataset_name] - external_lineage = self._external_lineage_map[dataset_name] - if not (lineage or external_lineage): - logger.debug(f"No lineage found for {dataset_name}") - return None - upstream_tables: List[UpstreamClass] = [] - column_lineage: Dict[str, str] = {} - for lineage_entry in lineage: - # Update the table-lineage - upstream_table_name = lineage_entry[0] - if not self._is_dataset_allowed(upstream_table_name): - continue - upstream_table = UpstreamClass( - dataset=builder.make_dataset_urn_with_platform_instance( - self.platform, - upstream_table_name, - self.config.platform_instance, - self.config.env, - ), - type=DatasetLineageTypeClass.TRANSFORMED, - ) - upstream_tables.append(upstream_table) - # Update column-lineage for each down-stream column. - upstream_columns = [ - d["columnName"].lower() for d in json.loads(lineage_entry[1]) - ] - downstream_columns = [ - d["columnName"].lower() for d in json.loads(lineage_entry[2]) - ] - upstream_column_str = ( - f"{upstream_table_name}({', '.join(sorted(upstream_columns))})" - ) - downstream_column_str = ( - f"{dataset_name}({', '.join(sorted(downstream_columns))})" - ) - column_lineage_key = f"column_lineage[{upstream_table_name}]" - column_lineage_value = ( - f"{{{upstream_column_str} -> {downstream_column_str}}}" - ) - column_lineage[column_lineage_key] = column_lineage_value - logger.debug(f"{column_lineage_key}:{column_lineage_value}") - - for external_lineage_entry in external_lineage: - # For now, populate only for S3 - if external_lineage_entry.startswith("s3://"): - external_upstream_table = UpstreamClass( - dataset=make_s3_urn(external_lineage_entry, self.config.env), - type=DatasetLineageTypeClass.COPY, - ) - upstream_tables.append(external_upstream_table) - - if upstream_tables: - logger.debug( - f"Upstream lineage of '{dataset_name}': {[u.dataset for u in upstream_tables]}" - ) - if self.config.upstream_lineage_in_report: - self.report.upstream_lineage[dataset_name] = [ - u.dataset for u in upstream_tables - ] - return UpstreamLineage(upstreams=upstream_tables), column_lineage - return None - def add_config_to_report(self): self.report.cleaned_account_id = self.config.get_account() self.report.ignore_start_time_lineage = self.config.ignore_start_time_lineage @@ -866,7 +496,13 @@ def get_workunits(self) -> Iterable[Union[MetadataWorkUnit, SqlWorkUnit]]: dataset_snapshot: DatasetSnapshot = wu.metadata.proposedSnapshot assert dataset_snapshot # Join the workunit stream from super with the lineage info using the urn. - lineage_info = self._get_upstream_lineage_info(dataset_snapshot.urn) + + dataset_name = self.get_dataset_name_from_urn(dataset_snapshot.urn) + lineage_info = ( + self.lineage_extractor._get_upstream_lineage_info(dataset_name) + if dataset_name is not None + else None + ) if lineage_info is not None: # Emit the lineage work unit upstream_lineage, upstream_column_props = lineage_info @@ -911,26 +547,6 @@ def get_workunits(self) -> Iterable[Union[MetadataWorkUnit, SqlWorkUnit]]: # Emit the work unit from super. yield wu - def _is_dataset_allowed( - self, dataset_name: Optional[str], is_view: bool = False - ) -> bool: - # View lineages is not supported. 
Add the allow/deny pattern for that when it is supported. - if dataset_name is None: - return True - dataset_params = dataset_name.split(".") - if len(dataset_params) != 3: - return True - if ( - not self.config.database_pattern.allowed(dataset_params[0]) - or not self.config.schema_pattern.allowed(dataset_params[1]) - or ( - not is_view and not self.config.table_pattern.allowed(dataset_params[2]) - ) - or (is_view and not self.config.view_pattern.allowed(dataset_params[2])) - ): - return False - return True - def generate_profile_candidates( self, inspector: Inspector, threshold_time: Optional[datetime], schema: str ) -> Optional[List[str]]: @@ -974,3 +590,10 @@ def generate_profile_candidates( def get_platform_instance_id(self) -> str: """Overrides the source identifier for stateful ingestion.""" return self.config.get_account() + + def get_dataset_name_from_urn(self, dataset_urn: str) -> Optional[str]: + dataset_key = builder.dataset_urn_to_key(dataset_urn) + if dataset_key is None: + logger.warning(f"Invalid dataset urn {dataset_urn}. Could not get key!") + return None + return dataset_key.name
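
A minimal, illustrative sketch (not part of this patch) of the urn-to-name round trip that the new `get_dataset_name_from_urn` helper relies on. The dataset name and env below are placeholders, and the `datahub.emitter.mce_builder` import path is an assumption: the diff only shows the `builder` alias.

```python
# Illustrative sketch only - not part of this patch.
# Assumes the module-level `builder` alias is datahub.emitter.mce_builder,
# and uses placeholder dataset/env values.
import datahub.emitter.mce_builder as builder

# Build a dataset urn the same way the lineage code does
# (positional args: platform, name, platform_instance, env).
dataset_urn = builder.make_dataset_urn_with_platform_instance(
    "snowflake", "demo_db.public.orders", None, "PROD"
)

# get_dataset_name_from_urn reverses this: urn -> key -> "db.schema.table" name.
dataset_key = builder.dataset_urn_to_key(dataset_urn)
if dataset_key is not None:
    print(dataset_key.name)  # demo_db.public.orders
```

The refactored `get_workunits` passes this name, rather than the urn, to `SnowflakeLineageExtractor._get_upstream_lineage_info`, so the extractor can key its lineage lookups on fully qualified dataset names.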