Skip to content

Commit

Permalink
Refactor: Consolidate import pytest (#34190)
Browse files Browse the repository at this point in the history
  • Loading branch information
eumiro authored Sep 11, 2023
1 parent a1fe77b commit 48930bc
Show file tree
Hide file tree
Showing 33 changed files with 141 additions and 160 deletions.
9 changes: 4 additions & 5 deletions kubernetes_tests/test_kubernetes_pod_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from kubernetes.client import V1EnvVar, V1PodSecurityContext, V1SecurityContext, models as k8s
from kubernetes.client.api_client import ApiClient
from kubernetes.client.rest import ApiException
from pytest import param

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.models import DAG, Connection, DagRun, TaskInstance
Expand Down Expand Up @@ -395,7 +394,7 @@ def test_pod_resources(self, mock_get_connection):
@pytest.mark.parametrize(
"val",
[
param(
pytest.param(
k8s.V1Affinity(
node_affinity=k8s.V1NodeAffinity(
required_during_scheduling_ignored_during_execution=k8s.V1NodeSelector(
Expand All @@ -415,7 +414,7 @@ def test_pod_resources(self, mock_get_connection):
),
id="current",
),
param(
pytest.param(
{
"nodeAffinity": {
"requiredDuringSchedulingIgnoredDuringExecution": {
Expand Down Expand Up @@ -729,8 +728,8 @@ def test_pod_template_file_system(self, mock_get_connection):
@pytest.mark.parametrize(
"env_vars",
[
param([k8s.V1EnvVar(name="env_name", value="value")], id="current"),
param({"env_name": "value"}, id="backcompat"), # todo: remove?
pytest.param([k8s.V1EnvVar(name="env_name", value="value")], id="current"),
pytest.param({"env_name": "value"}, id="backcompat"), # todo: remove?
],
)
def test_pod_template_file_with_overrides_system(self, env_vars, test_label, mock_get_connection):
Expand Down
9 changes: 4 additions & 5 deletions tests/cli/commands/test_db_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import pendulum
import pytest
from pytest import param
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import OperationalError

Expand Down Expand Up @@ -102,18 +101,18 @@ def test_cli_upgrade_success(self, mock_upgradedb, args, called_with):
@pytest.mark.parametrize(
"args, pattern",
[
param(["--to-version", "2.1.25"], "not supported", id="bad version"),
param(
pytest.param(["--to-version", "2.1.25"], "not supported", id="bad version"),
pytest.param(
["--to-revision", "abc", "--from-revision", "abc123"],
"used with `--show-sql-only`",
id="requires offline",
),
param(
pytest.param(
["--to-revision", "abc", "--from-version", "2.0.2"],
"used with `--show-sql-only`",
id="requires offline",
),
param(
pytest.param(
["--to-revision", "abc", "--from-version", "2.1.25", "--show-sql-only"],
"Unknown version",
id="bad version",
Expand Down
5 changes: 2 additions & 3 deletions tests/core/test_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from unittest.mock import patch

import pytest
from pytest import param

from airflow import configuration
from airflow.configuration import (
Expand Down Expand Up @@ -1482,8 +1481,8 @@ def test_suppress_future_warnings_no_future_warning(self):
@pytest.mark.parametrize(
"key",
[
param("deactivate_stale_dags_interval", id="old"),
param("parsing_cleanup_interval", id="new"),
pytest.param("deactivate_stale_dags_interval", id="old"),
pytest.param("parsing_cleanup_interval", id="new"),
],
)
def test_future_warning_only_for_code_ref(self, key):
Expand Down
11 changes: 8 additions & 3 deletions tests/core/test_otel_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

import pytest
from opentelemetry.metrics import MeterProvider
from pytest import param

from airflow.exceptions import InvalidStatsNameException
from airflow.metrics.otel_logger import (
Expand Down Expand Up @@ -76,8 +75,14 @@ def test_exemption_list_has_not_grown(self):
@pytest.mark.parametrize(
"invalid_stat_combo",
[
*[param(("prefix", name), id=f"Stat name {msg}.") for (name, msg) in INVALID_STAT_NAME_CASES],
*[param((prefix, "name"), id=f"Stat prefix {msg}.") for (prefix, msg) in INVALID_STAT_NAME_CASES],
*[
pytest.param(("prefix", name), id=f"Stat name {msg}.")
for (name, msg) in INVALID_STAT_NAME_CASES
],
*[
pytest.param((prefix, "name"), id=f"Stat prefix {msg}.")
for (prefix, msg) in INVALID_STAT_NAME_CASES
],
],
)
def test_invalid_stat_names_are_caught(self, invalid_stat_combo):
Expand Down
3 changes: 1 addition & 2 deletions tests/executors/test_base_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import pendulum
import pytest
import time_machine
from pytest import mark

from airflow.executors.base_executor import BaseExecutor, RunningRetryAttemptType
from airflow.models.baseoperator import BaseOperator
Expand Down Expand Up @@ -132,7 +131,7 @@ def setup_trigger_tasks(dag_maker):
return executor, dagrun


@mark.parametrize("open_slots", [1, 2, 3])
@pytest.mark.parametrize("open_slots", [1, 2, 3])
def test_trigger_queued_tasks(dag_maker, open_slots):
executor, _ = setup_trigger_tasks(dag_maker)
executor.trigger_tasks(open_slots)
Expand Down
14 changes: 7 additions & 7 deletions tests/hooks/test_package_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""Test for Package Index Hook."""
from __future__ import annotations

from pytest import FixtureRequest, MonkeyPatch, fixture, mark, raises
import pytest

from airflow.hooks.package_index import PackageIndexHook
from airflow.models.connection import Connection
Expand Down Expand Up @@ -60,11 +60,11 @@ def __init__(self, host: str | None, login: str | None, password: str | None):
}


@fixture(
@pytest.fixture(
params=list(PI_MOCK_TESTDATA.values()),
ids=list(PI_MOCK_TESTDATA.keys()),
)
def mock_get_connection(monkeypatch: MonkeyPatch, request: FixtureRequest) -> str | None:
def mock_get_connection(monkeypatch: pytest.MonkeyPatch, request: pytest.FixtureRequest) -> str | None:
"""Pytest Fixture."""
testdata: dict[str, str | None] = request.param
host: str | None = testdata.get("host", None)
Expand All @@ -86,12 +86,12 @@ def test_get_connection_url(mock_get_connection: str | None):
connection_url = hook_instance.get_connection_url()
assert connection_url == expected_result
else:
with raises(Exception):
with pytest.raises(Exception):
hook_instance.get_connection_url()


@mark.parametrize("success", [0, 1])
def test_test_connection(monkeypatch: MonkeyPatch, mock_get_connection: str | None, success: int):
@pytest.mark.parametrize("success", [0, 1])
def test_test_connection(monkeypatch: pytest.MonkeyPatch, mock_get_connection: str | None, success: int):
"""Test if connection test responds correctly to return code."""

def mock_run(*_, **__):
Expand All @@ -110,7 +110,7 @@ class MockProc:
result = hook_instance.test_connection()
assert result[0] == (success == 0)
else:
with raises(Exception):
with pytest.raises(Exception):
hook_instance.test_connection()


Expand Down
2 changes: 1 addition & 1 deletion tests/listeners/test_listeners.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import contextlib
import os

import pytest as pytest
import pytest

from airflow import AirflowException
from airflow.jobs.job import Job, run_job
Expand Down
11 changes: 5 additions & 6 deletions tests/models/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from __future__ import annotations

import pytest
from pytest import param

from airflow.models.base import get_id_collation_args
from tests.test_utils.config import conf_vars
Expand All @@ -26,16 +25,16 @@
@pytest.mark.parametrize(
("dsn", "expected", "extra"),
[
param("postgresql://host/the_database", {}, {}, id="postgres"),
param("mysql://host/the_database", {"collation": "utf8mb3_bin"}, {}, id="mysql"),
param("mysql+pymsql://host/the_database", {"collation": "utf8mb3_bin"}, {}, id="mysql+pymsql"),
param(
pytest.param("postgresql://host/the_database", {}, {}, id="postgres"),
pytest.param("mysql://host/the_database", {"collation": "utf8mb3_bin"}, {}, id="mysql"),
pytest.param("mysql+pymsql://host/the_database", {"collation": "utf8mb3_bin"}, {}, id="mysql+pymsql"),
pytest.param(
"mysql://host/the_database",
{"collation": "ascii"},
{("database", "sql_engine_collation_for_ids"): "ascii"},
id="mysql with explicit config",
),
param(
pytest.param(
"postgresql://host/the_database",
{"collation": "ascii"},
{("database", "sql_engine_collation_for_ids"): "ascii"},
Expand Down
33 changes: 16 additions & 17 deletions tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
import pendulum
import pytest
import time_machine
from pytest import param

from airflow import models, settings
from airflow.decorators import task, task_group
Expand Down Expand Up @@ -1165,7 +1164,7 @@ def test_depends_on_past(self, dag_maker):
# expect_passed
# states: success, skipped, failed, upstream_failed, removed, done, success_setup, skipped_setup
# all setups succeeded - one
param(
pytest.param(
"all_done_setup_success",
1,
_UpstreamTIStates(6, 0, 0, 0, 0, 6, 1, 0),
Expand All @@ -1174,7 +1173,7 @@ def test_depends_on_past(self, dag_maker):
True,
id="all setups succeeded - one",
),
param(
pytest.param(
"all_done_setup_success",
2,
_UpstreamTIStates(7, 0, 0, 0, 0, 7, 2, 0),
Expand All @@ -1183,7 +1182,7 @@ def test_depends_on_past(self, dag_maker):
True,
id="all setups succeeded - two",
),
param(
pytest.param(
"all_done_setup_success",
1,
_UpstreamTIStates(5, 0, 1, 0, 0, 6, 0, 0),
Expand All @@ -1192,7 +1191,7 @@ def test_depends_on_past(self, dag_maker):
False,
id="setups failed - one",
),
param(
pytest.param(
"all_done_setup_success",
2,
_UpstreamTIStates(5, 0, 2, 0, 0, 7, 0, 0),
Expand All @@ -1201,7 +1200,7 @@ def test_depends_on_past(self, dag_maker):
False,
id="setups failed - two",
),
param(
pytest.param(
"all_done_setup_success",
1,
_UpstreamTIStates(5, 1, 0, 0, 0, 6, 0, 1),
Expand All @@ -1210,7 +1209,7 @@ def test_depends_on_past(self, dag_maker):
False,
id="setups skipped - one",
),
param(
pytest.param(
"all_done_setup_success",
2,
_UpstreamTIStates(5, 2, 0, 0, 0, 7, 0, 2),
Expand All @@ -1219,7 +1218,7 @@ def test_depends_on_past(self, dag_maker):
False,
id="setups skipped - two",
),
param(
pytest.param(
"all_done_setup_success",
2,
_UpstreamTIStates(5, 1, 1, 0, 0, 7, 0, 1),
Expand All @@ -1228,7 +1227,7 @@ def test_depends_on_past(self, dag_maker):
False,
id="one setup failed one setup skipped",
),
param(
pytest.param(
"all_done_setup_success",
2,
_UpstreamTIStates(6, 0, 1, 0, 0, 7, 1, 0),
Expand All @@ -1237,7 +1236,7 @@ def test_depends_on_past(self, dag_maker):
True,
id="is teardown one setup failed one setup success",
),
param(
pytest.param(
"all_done_setup_success",
2,
_UpstreamTIStates(6, 0, 1, 0, 0, 7, 1, 0),
Expand All @@ -1246,7 +1245,7 @@ def test_depends_on_past(self, dag_maker):
True,
id="not teardown one setup failed one setup success",
),
param(
pytest.param(
"all_done_setup_success",
2,
_UpstreamTIStates(6, 1, 0, 0, 0, 7, 1, 1),
Expand All @@ -1255,7 +1254,7 @@ def test_depends_on_past(self, dag_maker):
True,
id="is teardown one setup success one setup skipped",
),
param(
pytest.param(
"all_done_setup_success",
2,
_UpstreamTIStates(6, 1, 0, 0, 0, 7, 1, 1),
Expand All @@ -1264,7 +1263,7 @@ def test_depends_on_past(self, dag_maker):
True,
id="not teardown one setup success one setup skipped",
),
param(
pytest.param(
"all_done_setup_success",
1,
_UpstreamTIStates(3, 0, 0, 0, 0, 3, 1, 0),
Expand All @@ -1273,7 +1272,7 @@ def test_depends_on_past(self, dag_maker):
False,
id="not all done",
),
param(
pytest.param(
"all_done_setup_success",
1,
_UpstreamTIStates(3, 0, 1, 0, 0, 4, 1, 0),
Expand All @@ -1282,7 +1281,7 @@ def test_depends_on_past(self, dag_maker):
False,
id="is teardown not all done one failed",
),
param(
pytest.param(
"all_done_setup_success",
1,
_UpstreamTIStates(3, 0, 1, 0, 0, 4, 1, 0),
Expand All @@ -1291,7 +1290,7 @@ def test_depends_on_past(self, dag_maker):
False,
id="not teardown not all done one failed",
),
param(
pytest.param(
"all_done_setup_success",
1,
_UpstreamTIStates(3, 1, 0, 0, 0, 4, 1, 0),
Expand All @@ -1300,7 +1299,7 @@ def test_depends_on_past(self, dag_maker):
False,
id="not all done one skipped",
),
param(
pytest.param(
"all_done_setup_success",
1,
_UpstreamTIStates(3, 1, 0, 0, 0, 4, 1, 0),
Expand Down
3 changes: 1 addition & 2 deletions tests/providers/amazon/aws/hooks/test_eks.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@
from datetime import datetime

from moto.core.exceptions import AWSError
from pytest import ExceptionInfo


@pytest.fixture(scope="function")
Expand Down Expand Up @@ -1332,7 +1331,7 @@ def assert_all_arn_values_are_valid(expected_arn_values, pattern, arn_under_test


def assert_client_error_exception_thrown(
expected_exception: type[AWSError], expected_msg: str, raised_exception: ExceptionInfo
expected_exception: type[AWSError], expected_msg: str, raised_exception: pytest.ExceptionInfo
) -> None:
"""
Asserts that the raised exception is of the expected type
Expand Down
Loading

0 comments on commit 48930bc

Please sign in to comment.