diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_tag.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_tag.py index 9307eb607be261..be449e963d270b 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_tag.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_tag.py @@ -169,5 +169,6 @@ def _filter_tags( self.report.report_entity_scanned(tag_identifier, "tag") if not self.config.tag_pattern.allowed(tag_identifier): self.report.report_dropped(tag_identifier) - allowed_tags.append(tag) + else: + allowed_tags.append(tag) return allowed_tags diff --git a/metadata-ingestion/tests/integration/snowflake/test_snowflake_tag.py b/metadata-ingestion/tests/integration/snowflake/test_snowflake_tag.py new file mode 100644 index 00000000000000..d5e265e7838825 --- /dev/null +++ b/metadata-ingestion/tests/integration/snowflake/test_snowflake_tag.py @@ -0,0 +1,102 @@ +from unittest import mock + +from datahub.configuration.common import AllowDenyPattern, DynamicTypedConfig +from datahub.ingestion.run.pipeline import Pipeline +from datahub.ingestion.run.pipeline_config import PipelineConfig, SourceConfig +from datahub.ingestion.source.snowflake.snowflake_config import ( + SnowflakeV2Config, + TagOption, +) +from datahub.ingestion.source.snowflake.snowflake_report import SnowflakeV2Report +from tests.integration.snowflake.common import default_query_results + + +def test_snowflake_tag_pattern(): + with mock.patch("snowflake.connector.connect") as mock_connect: + sf_connection = mock.MagicMock() + sf_cursor = mock.MagicMock() + mock_connect.return_value = sf_connection + sf_connection.cursor.return_value = sf_cursor + sf_cursor.execute.side_effect = default_query_results + + tag_config = SnowflakeV2Config( + account_id="ABC12345.ap-south-1.aws", + username="TST_USR", + password="TST_PWD", + match_fully_qualified_names=True, + schema_pattern=AllowDenyPattern(allow=["test_db.test_schema"]), + tag_pattern=AllowDenyPattern( + allow=["TEST_DB.TEST_SCHEMA.my_tag_1:my_value_1"] + ), + include_technical_schema=True, + include_table_lineage=False, + include_view_lineage=False, + include_column_lineage=False, + include_usage_stats=False, + include_operational_stats=False, + extract_tags=TagOption.without_lineage, + ) + + pipeline = Pipeline( + config=PipelineConfig( + source=SourceConfig(type="snowflake", config=tag_config), + sink=DynamicTypedConfig(type="blackhole", config={}), + ) + ) + pipeline.run() + pipeline.pretty_print_summary() + pipeline.raise_from_status() + + source_report = pipeline.source.get_report() + assert isinstance(source_report, SnowflakeV2Report) + assert source_report.tags_scanned == 5 + assert source_report._processed_tags == { + "TEST_DB.TEST_SCHEMA.my_tag_1:my_value_1" + } + + +def test_snowflake_tag_pattern_deny(): + with mock.patch("snowflake.connector.connect") as mock_connect: + sf_connection = mock.MagicMock() + sf_cursor = mock.MagicMock() + mock_connect.return_value = sf_connection + sf_connection.cursor.return_value = sf_cursor + sf_cursor.execute.side_effect = default_query_results + + tag_config = SnowflakeV2Config( + account_id="ABC12345.ap-south-1.aws", + username="TST_USR", + password="TST_PWD", + match_fully_qualified_names=True, + schema_pattern=AllowDenyPattern(allow=["test_db.test_schema"]), + tag_pattern=AllowDenyPattern( + deny=["TEST_DB.TEST_SCHEMA.my_tag_2:my_value_2"] + ), + include_technical_schema=True, + include_table_lineage=False, + include_view_lineage=False, + include_column_lineage=False, + include_usage_stats=False, + include_operational_stats=False, + extract_tags=TagOption.without_lineage, + ) + + pipeline = Pipeline( + config=PipelineConfig( + source=SourceConfig(type="snowflake", config=tag_config), + sink=DynamicTypedConfig(type="blackhole", config={}), + ) + ) + pipeline.run() + pipeline.pretty_print_summary() + pipeline.raise_from_status() + + source_report = pipeline.source.get_report() + assert isinstance(source_report, SnowflakeV2Report) + assert source_report.tags_scanned == 5 + assert source_report._processed_tags == { + "OTHER_DB.OTHER_SCHEMA.my_other_tag:other", + "TEST_DB.TEST_SCHEMA.my_tag_0:my_value_0", + "TEST_DB.TEST_SCHEMA.my_tag_1:my_value_1", + "TEST_DB.TEST_SCHEMA.security:pii", + }