Skip to content

Commit

Permalink
Return stacktrace from the DAG file in `test_should_not_do_database_q…
Browse files Browse the repository at this point in the history
…ueries` (#39331)
  • Loading branch information
Taragolis authored Apr 30, 2024
1 parent 2872d37 commit 3204d64
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 14 deletions.
2 changes: 1 addition & 1 deletion tests/always/test_example_dags.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def test_should_be_importable(example):
@pytest.mark.db_test
@pytest.mark.parametrize("example", example_dags_except_db_exception(), ids=relative_path)
def test_should_not_do_database_queries(example):
with assert_queries_count(0):
with assert_queries_count(0, stacklevel_from_module=example.rsplit(os.sep, 1)[-1]):
DagBag(
dag_folder=example,
include_examples=False,
Expand Down
2 changes: 2 additions & 0 deletions tests/deprecations_ignore.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
- tests/always/test_connection.py::TestConnection::test_connection_get_uri_from_uri
- tests/always/test_connection.py::TestConnection::test_connection_test_success
- tests/always/test_connection.py::TestConnection::test_from_json_extra
# `test_should_be_importable` and `test_should_not_do_database_queries` should be resolved together
- tests/always/test_example_dags.py::test_should_be_importable
- tests/always/test_example_dags.py::test_should_not_do_database_queries


# API
Expand Down
78 changes: 65 additions & 13 deletions tests/test_utils/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
from __future__ import annotations

import logging
import os
import re
import traceback
from collections import Counter
from contextlib import contextmanager
from typing import NamedTuple

from sqlalchemy import event

Expand All @@ -40,6 +42,47 @@ def _trim(s):
assert first_trim == second_trim, msg


class QueriesTraceRecord(NamedTuple):
module: str
name: str
lineno: int | None

@classmethod
def from_frame(cls, frame_summary: traceback.FrameSummary):
return cls(
module=frame_summary.filename.rsplit(os.sep, 1)[-1],
name=frame_summary.name,
lineno=frame_summary.lineno,
)

def __str__(self):
return f"{self.module}:{self.name}:{self.lineno}"


class QueriesTraceInfo(NamedTuple):
traces: tuple[QueriesTraceRecord, ...]

@classmethod
def from_traceback(cls, trace: traceback.StackSummary) -> QueriesTraceInfo:
records = [
QueriesTraceRecord.from_frame(f)
for f in trace
if "sqlalchemy" not in f.filename
and __file__ != f.filename
and ("session.py" not in f.filename and f.name != "wrapper")
]
return cls(traces=tuple(records))

def module_level(self, module: str) -> int:
stacklevel = 0
for ix, record in enumerate(reversed(self.traces), start=1):
if record.module == module:
stacklevel = ix
if stacklevel == 0:
raise LookupError(f"Unable to find module {stacklevel} in traceback")
return stacklevel


class CountQueries:
"""
Counts the number of queries sent to Airflow Database in a given context.
Expand All @@ -48,8 +91,10 @@ class CountQueries:
not be included.
"""

def __init__(self):
self.result = Counter()
def __init__(self, *, stacklevel: int = 1, stacklevel_from_module: str | None = None):
self.result: Counter[str] = Counter()
self.stacklevel = stacklevel
self.stacklevel_from_module = stacklevel_from_module

def __enter__(self):
event.listen(airflow.settings.engine, "after_cursor_execute", self.after_cursor_execute)
Expand All @@ -60,31 +105,38 @@ def __exit__(self, type_, value, tb):
log.debug("Queries count: %d", sum(self.result.values()))

def after_cursor_execute(self, *args, **kwargs):
stack = [
f
for f in traceback.extract_stack()
if "sqlalchemy" not in f.filename
and __file__ != f.filename
and ("session.py" not in f.filename and f.name != "wrapper")
]
stack_info = ">".join([f"{f.filename.rpartition('/')[-1]}:{f.name}:{f.lineno}" for f in stack][-5:])
self.result[f"{stack_info}"] += 1
stack = QueriesTraceInfo.from_traceback(traceback.extract_stack())
if not self.stacklevel_from_module:
stacklevel = self.stacklevel
else:
stacklevel = stack.module_level(self.stacklevel_from_module)

stack_info = " > ".join(map(str, stack.traces[-stacklevel:]))
self.result[stack_info] += 1


count_queries = CountQueries


@contextmanager
def assert_queries_count(expected_count: int, message_fmt: str | None = None, margin: int = 0):
def assert_queries_count(
expected_count: int,
message_fmt: str | None = None,
margin: int = 0,
stacklevel: int = 5,
stacklevel_from_module: str | None = None,
):
"""
Asserts that the number of queries is as expected with the margin applied
The margin is helpful in case of complex cases where we do not want to change it every time we
changed queries, but we want to catch cases where we spin out of control
:param expected_count: expected number of queries
:param message_fmt: message printed optionally if the number is exceeded
:param margin: margin to add to expected number of calls
:param stacklevel: limits the output stack trace to that numbers of frame
:param stacklevel_from_module: Filter stack trace from specific module
"""
with count_queries() as result:
with count_queries(stacklevel=stacklevel, stacklevel_from_module=stacklevel_from_module) as result:
yield None

count = sum(result.values())
Expand Down

0 comments on commit 3204d64

Please sign in to comment.