deduplicate queries, fix empty upstream lineage
mayurinehate committed Aug 7, 2024
1 parent 12fea30 commit cdfb94f
Showing 3 changed files with 115 additions and 23 deletions.
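
The first half of the change deduplicates observed queries: entries are grouped by query fingerprint within each usage time bucket, and repeats bump a usage multiplier instead of being re-added to the aggregator. A minimal, self-contained sketch of that idea, using illustrative names rather than the actual DataHub classes:

```python
import hashlib
from dataclasses import dataclass
from typing import Dict, Iterable, Optional


@dataclass
class Query:
    text: str
    timestamp_ms: Optional[int] = None
    usage_multiplier: int = 1


def fingerprint(sql: str) -> str:
    # Trivial stand-in for a real query fingerprint (e.g. one that normalizes literals).
    return hashlib.sha1(" ".join(sql.lower().split()).encode()).hexdigest()


def dedupe(queries: Iterable[Query], bucket_ms: int = 3_600_000) -> Dict[str, Dict[int, Query]]:
    # fingerprint -> {bucket start (ms) -> representative query}
    deduped: Dict[str, Dict[int, Query]] = {}
    for q in queries:
        bucket = (q.timestamp_ms // bucket_ms) * bucket_ms if q.timestamp_ms is not None else 0
        per_bucket = deduped.setdefault(fingerprint(q.text), {})
        if bucket in per_bucket:
            # Same statement already seen in this bucket: count it instead of re-adding it.
            per_bucket[bucket].usage_multiplier += 1
            per_bucket[bucket].timestamp_ms = q.timestamp_ms
        else:
            per_bucket[bucket] = q
    return deduped
```

Each (fingerprint, bucket) pair collapses to a single entry whose usage_multiplier records how often the statement ran, which is what the diff below feeds into the SQL parsing aggregator.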
@@ -4,13 +4,16 @@
import tempfile
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Iterable, List, Optional, TypedDict, Union
from typing import Dict, Iterable, List, MutableMapping, Optional, TypedDict

from google.cloud.bigquery import Client
from pydantic import Field

from datahub.configuration.common import AllowDenyPattern
from datahub.configuration.time_window_config import BaseTimeWindowConfig
from datahub.configuration.time_window_config import (
BaseTimeWindowConfig,
get_time_bucket,
)
from datahub.ingestion.api.report import Report
from datahub.ingestion.api.source import SourceReport
from datahub.ingestion.api.source_helpers import auto_workunit
@@ -32,13 +35,18 @@
from datahub.sql_parsing.schema_resolver import SchemaResolver
from datahub.sql_parsing.sql_parsing_aggregator import (
ObservedQuery,
PreparsedQuery,
SqlAggregatorReport,
SqlParsingAggregator,
)
from datahub.utilities.file_backed_collections import ConnectionWrapper, FileBackedList
from datahub.sql_parsing.sqlglot_utils import get_query_fingerprint
from datahub.utilities.file_backed_collections import (
ConnectionWrapper,
FileBackedDict,
FileBackedList,
)
from datahub.utilities.perf_timer import PerfTimer
from datahub.utilities.stats_collections import TopKDict, int_top_k_dict
from datahub.utilities.time import datetime_to_ts_millis

logger = logging.getLogger(__name__)

@@ -73,6 +81,7 @@ class BigQueryJob(TypedDict):
total_bytes_processed: int
dml_statistics: Optional[DMLJobStatistics]
session_id: Optional[str]
query_hash: Optional[str]


class BigQueryQueriesExtractorConfig(BigQueryBaseConfig):
@@ -98,14 +107,23 @@ class BigQueryQueriesExtractorConfig(BigQueryBaseConfig):
include_query_usage_statistics: bool = False
include_operations: bool = True

region_qualifiers: List[str] = Field(
default=["region-us", "region-eu"],
description="BigQuery regions to be scanned for bigquery jobs. See [this](https://cloud.google.com/bigquery/docs/information-schema-jobs) for details.",
)


@dataclass
class BigQueryQueriesExtractorReport(Report):
query_log_fetch_timer: PerfTimer = field(default_factory=PerfTimer)
audit_log_preprocessing_timer: PerfTimer = field(default_factory=PerfTimer)
audit_log_load_timer: PerfTimer = field(default_factory=PerfTimer)
sql_aggregator: Optional[SqlAggregatorReport] = None
num_queries_by_project: TopKDict[str, int] = field(default_factory=int_top_k_dict)

num_total_queries: int = 0
num_unique_queries: int = 0


class BigQueryQueriesExtractor:
def __init__(
@@ -197,7 +215,7 @@ def get_workunits_internal(
audit_log_file = self.local_temp_path / "audit_log.sqlite"
use_cached_audit_log = audit_log_file.exists()

queries: FileBackedList[Union[PreparsedQuery, ObservedQuery]]
queries: FileBackedList[ObservedQuery]
if use_cached_audit_log:
logger.info("Using cached audit log")
shared_connection = ConnectionWrapper(audit_log_file)
@@ -207,7 +225,7 @@

shared_connection = ConnectionWrapper(audit_log_file)
queries = FileBackedList(shared_connection)
entry: Union[PreparsedQuery, ObservedQuery]
entry: ObservedQuery

with self.report.query_log_fetch_timer:
for project in get_projects(
@@ -216,22 +234,62 @@
for entry in self.fetch_query_log(project):
self.report.num_queries_by_project[project.id] += 1
queries.append(entry)
self.report.num_total_queries = len(queries)

with self.report.audit_log_preprocessing_timer:
# Preprocessing stage that deduplicates the queries using query hash per usage bucket
queries_deduped: MutableMapping[str, Dict[int, ObservedQuery]]
queries_deduped = self.deduplicate_queries(queries)
self.report.num_unique_queries = len(queries_deduped)

with self.report.audit_log_load_timer:
for i, query in enumerate(queries):
if i % 1000 == 0:
logger.info(f"Added {i} query log entries to SQL aggregator")
self.aggregator.add(query)
i = 0
# Is FileBackedDict OrderedDict ? i.e. keys / values are retrieved in same order as added ?
# Does aggregator expect to see queries in same order as they were executed ?
for query_instances in queries_deduped.values():
for _, query in query_instances.items():
if i > 0 and i % 1000 == 0:
logger.info(f"Added {i} query log entries to SQL aggregator")

self.aggregator.add(query)
i += 1

yield from auto_workunit(self.aggregator.gen_metadata())

def fetch_query_log(
self, project: BigqueryProject
) -> Iterable[Union[PreparsedQuery, ObservedQuery]]:
def deduplicate_queries(
self, queries: FileBackedList[ObservedQuery]
) -> MutableMapping[str, Dict[int, ObservedQuery]]:
queries_deduped: FileBackedDict[Dict[int, ObservedQuery]] = FileBackedDict()
for query in queries:
time_bucket = (
datetime_to_ts_millis(
get_time_bucket(query.timestamp, self.config.window.bucket_duration)
)
if query.timestamp
else 0
)
query_hash = get_query_fingerprint(
query.query, self.identifiers.platform, fast=True
)
query.query_hash = query_hash
if query_hash not in queries_deduped:
queries_deduped[query_hash] = {time_bucket: query}
else:
seen_query = queries_deduped[query_hash]
if time_bucket not in seen_query:
seen_query[time_bucket] = query
else:
observed_query = seen_query[time_bucket]
observed_query.usage_multiplier += 1
observed_query.timestamp = query.timestamp
queries_deduped[query_hash] = seen_query

return queries_deduped

def fetch_query_log(self, project: BigqueryProject) -> Iterable[ObservedQuery]:

# Multi-regions from https://cloud.google.com/bigquery/docs/locations#supported_locations
regions = ["region-us", "region-eu"]
# TODO: support other regions as required - via a config
regions = self.config.region_qualifiers

for region in regions:
# Each region needs to be a different query
@@ -265,10 +323,7 @@ def fetch_query_log(
else:
yield entry

def _parse_audit_log_row(
self, row: BigQueryJob
) -> Union[ObservedQuery, PreparsedQuery]:

def _parse_audit_log_row(self, row: BigQueryJob) -> ObservedQuery:
timestamp: datetime = row["creation_time"]
timestamp = timestamp.astimezone(timezone.utc)

@@ -284,6 +339,8 @@ def _parse_audit_log_row(
),
default_db=row["project_id"],
default_schema=None,
# Not using BQ query hash as it's not always present
# query_hash=row["query_hash"],
)

return entry
@@ -319,12 +376,14 @@ def _build_enriched_query_log_query(
total_bytes_billed,
total_bytes_processed,
dml_statistics,
session_info.session_id as session_id
session_info.session_id as session_id,
query_info.query_hashes.normalized_literals as query_hash
FROM
`{project_id}`.`{region}`.INFORMATION_SCHEMA.JOBS
WHERE
creation_time >= '{audit_start_time}' AND
creation_time <= '{audit_end_time}' AND
error_result is null AND
not CONTAINS_SUBSTR(query, '.INFORMATION_SCHEMA.')
ORDER BY creation_time
"""
@@ -83,7 +83,10 @@ class LoggedQuery:
default_schema: Optional[str]


ObservedQuery = LoggedQuery
@dataclasses.dataclass
class ObservedQuery(LoggedQuery):
query_hash: Optional[str] = None
usage_multiplier: int = 1


@dataclasses.dataclass
@@ -489,8 +492,10 @@ def add(
default_db=item.default_db,
default_schema=item.default_schema,
session_id=item.session_id,
usage_multiplier=item.usage_multiplier,
query_timestamp=item.timestamp,
user=CorpUserUrn.from_string(item.user) if item.user else None,
query_hash=item.query_hash,
)
else:
raise ValueError(f"Cannot add unknown item type: {type(item)}")
@@ -634,6 +639,7 @@ def add_observed_query(
usage_multiplier: int = 1,
is_known_temp_table: bool = False,
require_out_table_schema: bool = False,
query_hash: Optional[str] = None,
) -> None:
"""Add an observed query to the aggregator.
@@ -677,8 +683,7 @@
if isinstance(parsed.debug_info.column_error, CooperativeTimeoutError):
self.report.num_observed_queries_column_timeout += 1

query_fingerprint = parsed.query_fingerprint

query_fingerprint = query_hash or parsed.query_fingerprint
self.add_preparsed_query(
PreparsedQuery(
query_id=query_fingerprint,
@@ -1150,6 +1155,9 @@ def _gen_lineage_for_downstream(
upstream_aspect.fineGrainedLineages or None
)

if not upstream_aspect.upstreams and not upstream_aspect.fineGrainedLineages:
return

yield MetadataChangeProposalWrapper(
entityUrn=downstream_urn,
aspect=upstream_aspect,
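
The other half of the commit, in the aggregator's lineage generation above, skips emitting an UpstreamLineage aspect when there is neither table-level nor column-level lineage to report. A stripped-down sketch of the guard pattern, with stand-in types rather than the real MCP classes:

```python
from dataclasses import dataclass, field
from typing import Iterable, List, Optional, Tuple


@dataclass
class UpstreamLineage:
    upstreams: List[str] = field(default_factory=list)
    fine_grained_lineages: Optional[List[str]] = None


def gen_lineage(downstream_urn: str, aspect: UpstreamLineage) -> Iterable[Tuple[str, UpstreamLineage]]:
    # An aspect with no upstreams and no fine-grained lineage carries no information,
    # so emit nothing at all rather than an empty aspect.
    if not aspect.upstreams and not aspect.fine_grained_lineages:
        return
    yield (downstream_urn, aspect)


# No upstreams and no column lineage: nothing is emitted.
assert list(gen_lineage("urn:li:dataset:example", UpstreamLineage())) == []
```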
25 changes: 25 additions & 0 deletions metadata-ingestion/tests/unit/sql_parsing/test_sql_aggregator.py
@@ -499,3 +499,28 @@ def test_table_rename(pytestconfig: pytest.Config) -> None:
outputs=mcps,
golden_path=RESOURCE_DIR / "test_table_rename.json",
)


@freeze_time(FROZEN_TIME)
def test_create_table_query_mcps(pytestconfig: pytest.Config) -> None:
aggregator = SqlParsingAggregator(
platform="bigquery",
generate_lineage=True,
generate_usage_statistics=False,
generate_operations=True,
)

aggregator.add_observed_query(
query="create or replace table `dataset.foo` (date_utc timestamp, revenue int);",
default_db="dev",
default_schema="public",
query_timestamp=datetime.now(),
)

mcps = list(aggregator.gen_metadata())

mce_helpers.check_goldens_stream(
pytestconfig,
outputs=mcps,
golden_path=RESOURCE_DIR / "test_create_table_query_mcps.json",
)
