From 3204d6492a5ea4da68b6dd4d944d5097002d21f8 Mon Sep 17 00:00:00 2001 From: Andrey Anshin Date: Tue, 30 Apr 2024 22:30:34 +0400 Subject: [PATCH] Return stacktrace from the DAG file in `test_should_not_do_database_queries` (#39331) --- tests/always/test_example_dags.py | 2 +- tests/deprecations_ignore.yml | 2 + tests/test_utils/asserts.py | 78 +++++++++++++++++++++++++------ 3 files changed, 68 insertions(+), 14 deletions(-) diff --git a/tests/always/test_example_dags.py b/tests/always/test_example_dags.py index 9c6ef1be29d14..b52d755a5271f 100644 --- a/tests/always/test_example_dags.py +++ b/tests/always/test_example_dags.py @@ -117,7 +117,7 @@ def test_should_be_importable(example): @pytest.mark.db_test @pytest.mark.parametrize("example", example_dags_except_db_exception(), ids=relative_path) def test_should_not_do_database_queries(example): - with assert_queries_count(0): + with assert_queries_count(0, stacklevel_from_module=example.rsplit(os.sep, 1)[-1]): DagBag( dag_folder=example, include_examples=False, diff --git a/tests/deprecations_ignore.yml b/tests/deprecations_ignore.yml index d54ba980f8045..5e1432357ecea 100644 --- a/tests/deprecations_ignore.yml +++ b/tests/deprecations_ignore.yml @@ -34,7 +34,9 @@ - tests/always/test_connection.py::TestConnection::test_connection_get_uri_from_uri - tests/always/test_connection.py::TestConnection::test_connection_test_success - tests/always/test_connection.py::TestConnection::test_from_json_extra +# `test_should_be_importable` and `test_should_not_do_database_queries` should be resolved together - tests/always/test_example_dags.py::test_should_be_importable +- tests/always/test_example_dags.py::test_should_not_do_database_queries # API diff --git a/tests/test_utils/asserts.py b/tests/test_utils/asserts.py index fffbcaef2387b..d06bb454de4d6 100644 --- a/tests/test_utils/asserts.py +++ b/tests/test_utils/asserts.py @@ -17,10 +17,12 @@ from __future__ import annotations import logging +import os import re import 
class QueriesTraceRecord(NamedTuple):
    """One stack frame reduced to the fields printed in a query trace."""

    module: str  # file name of the frame with the directory part removed
    name: str  # name recorded for the frame (function name or ``<module>``)
    lineno: int | None  # line number recorded for the frame, if any

    @classmethod
    def from_frame(cls, frame_summary: traceback.FrameSummary) -> QueriesTraceRecord:
        """
        Build a record from a ``traceback.FrameSummary``.

        :param frame_summary: frame to convert
        :return: record holding the frame's base file name, name and line number
        """
        return cls(
            # Keep only the last path component so traces stay short and comparable
            # with a bare file name such as ``example_dag.py``.
            module=frame_summary.filename.rsplit(os.sep, 1)[-1],
            name=frame_summary.name,
            lineno=frame_summary.lineno,
        )

    def __str__(self):
        return f"{self.module}:{self.name}:{self.lineno}"


class QueriesTraceInfo(NamedTuple):
    """Filtered stack trace describing where a database query was issued from."""

    traces: tuple[QueriesTraceRecord, ...]  # kept frames, in the order they had in the source stack

    @classmethod
    def from_traceback(cls, trace: traceback.StackSummary) -> QueriesTraceInfo:
        """
        Build trace info from a stack summary, dropping frames that are only noise.

        A frame is dropped when its file name contains ``sqlalchemy``, when it belongs
        to this very file, when its file name contains ``session.py`` or when its name
        is ``wrapper``.

        :param trace: stack summary, e.g. the result of ``traceback.extract_stack()``
        :return: trace info with the remaining frames in their original order
        """
        records = [
            QueriesTraceRecord.from_frame(f)
            for f in trace
            if "sqlalchemy" not in f.filename
            and __file__ != f.filename
            and ("session.py" not in f.filename and f.name != "wrapper")
        ]
        return cls(traces=tuple(records))

    def module_level(self, module: str) -> int:
        """
        Return how many trailing frames are needed to reach ``module``.

        The result counts frames from the end of ``traces`` back to the *first*
        frame whose file name equals ``module``, so ``traces[-result:]`` starts at
        that frame.

        :param module: base file name to look for, e.g. ``example_dag.py``
        :return: number of trailing frames, at least 1
        :raises LookupError: if no frame in ``traces`` comes from ``module``
        """
        total = len(self.traces)
        # The first match scanning forward is the last match scanning backwards,
        # which is the frame the slice has to start from.
        for index, record in enumerate(self.traces):
            if record.module == module:
                return total - index
        raise LookupError(f"Unable to find module {module!r} in traceback")
""" - def __init__(self): - self.result = Counter() + def __init__(self, *, stacklevel: int = 1, stacklevel_from_module: str | None = None): + self.result: Counter[str] = Counter() + self.stacklevel = stacklevel + self.stacklevel_from_module = stacklevel_from_module def __enter__(self): event.listen(airflow.settings.engine, "after_cursor_execute", self.after_cursor_execute) @@ -60,22 +105,27 @@ def __exit__(self, type_, value, tb): log.debug("Queries count: %d", sum(self.result.values())) def after_cursor_execute(self, *args, **kwargs): - stack = [ - f - for f in traceback.extract_stack() - if "sqlalchemy" not in f.filename - and __file__ != f.filename - and ("session.py" not in f.filename and f.name != "wrapper") - ] - stack_info = ">".join([f"{f.filename.rpartition('/')[-1]}:{f.name}:{f.lineno}" for f in stack][-5:]) - self.result[f"{stack_info}"] += 1 + stack = QueriesTraceInfo.from_traceback(traceback.extract_stack()) + if not self.stacklevel_from_module: + stacklevel = self.stacklevel + else: + stacklevel = stack.module_level(self.stacklevel_from_module) + + stack_info = " > ".join(map(str, stack.traces[-stacklevel:])) + self.result[stack_info] += 1 count_queries = CountQueries @contextmanager -def assert_queries_count(expected_count: int, message_fmt: str | None = None, margin: int = 0): +def assert_queries_count( + expected_count: int, + message_fmt: str | None = None, + margin: int = 0, + stacklevel: int = 5, + stacklevel_from_module: str | None = None, +): """ Asserts that the number of queries is as expected with the margin applied The margin is helpful in case of complex cases where we do not want to change it every time we @@ -83,8 +133,10 @@ def assert_queries_count(expected_count: int, message_fmt: str | None = None, ma :param expected_count: expected number of queries :param message_fmt: message printed optionally if the number is exceeded :param margin: margin to add to expected number of calls + :param stacklevel: limits the output stack trace to 
that numbers of frame + :param stacklevel_from_module: Filter stack trace from specific module """ - with count_queries() as result: + with count_queries(stacklevel=stacklevel, stacklevel_from_module=stacklevel_from_module) as result: yield None count = sum(result.values())