Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: Simplify code in tests #33293

Merged
merged 1 commit into from
Aug 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions tests/always/test_example_dags.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,15 @@ def example_not_suspended_dags():
for example_dir in example_dirs:
candidates = glob(f"{AIRFLOW_SOURCES_ROOT.as_posix()}/{example_dir}", recursive=True)
for candidate in candidates:
if any(candidate.startswith(s) for s in suspended_providers_folders):
continue
yield candidate
if not candidate.startswith(tuple(suspended_providers_folders)):
yield candidate


def example_dags_except_db_exception():
return [
dag_file
for dag_file in example_not_suspended_dags()
if not any(dag_file.endswith(e) for e in NO_DB_QUERY_EXCEPTION)
if not dag_file.endswith(tuple(NO_DB_QUERY_EXCEPTION))
]


Expand Down
4 changes: 2 additions & 2 deletions tests/always/test_project_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def get_classes_from_file(self, filepath: str):
if not isinstance(current_node, ast.ClassDef):
continue
name = current_node.name
if not any(name.endswith(suffix) for suffix in self.CLASS_SUFFIXES):
if not name.endswith(tuple(self.CLASS_SUFFIXES)):
continue
results[f"{module}.{name}"] = current_node
return results
Expand Down Expand Up @@ -463,6 +463,6 @@ def test_no_illegal_suffixes(self):
)
)

invalid_files = [f for f in files if any(f.endswith(suffix) for suffix in illegal_suffixes)]
invalid_files = [f for f in files if f.endswith(tuple(illegal_suffixes))]

assert [] == invalid_files
4 changes: 1 addition & 3 deletions tests/cli/commands/test_config_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,7 @@ def test_cli_comment_out_everything(self):
)
output = temp_stdout.getvalue()
lines = output.split("\n")
assert all(
line.startswith("#") or line.strip() == "" or line.startswith("[") for line in lines if line
)
assert all(not line.strip() or line.startswith(("#", "[")) for line in lines if line)


class TestCliConfigGetValue:
Expand Down
2 changes: 1 addition & 1 deletion tests/decorators/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ def add_num(number: int, num2: int = 2):
bigger_number.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
ti_add_num = [ti for ti in dr.get_task_instances() if ti.task_id == "add_num"][0]
ti_add_num = next(ti for ti in dr.get_task_instances() if ti.task_id == "add_num")
assert ti_add_num.xcom_pull(key=ret.key) == (test_number + 2) * 2

def test_dag_task(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/jobs/test_scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -3123,7 +3123,7 @@ def test_list_py_file_paths(self):
}
for root, _, files in os.walk(TEST_DAG_FOLDER):
for file_name in files:
if file_name.endswith(".py") or file_name.endswith(".zip"):
if file_name.endswith((".py", ".zip")):
if file_name not in ignored_files:
expected_files.add(f"{root}/{file_name}")
for file_path in list_py_file_paths(TEST_DAG_FOLDER, include_examples=False):
Expand All @@ -3136,7 +3136,7 @@ def test_list_py_file_paths(self):
example_dag_folder = airflow.example_dags.__path__[0]
for root, _, files in os.walk(example_dag_folder):
for file_name in files:
if file_name.endswith(".py") or file_name.endswith(".zip"):
if file_name.endswith((".py", ".zip")):
if file_name not in ["__init__.py"] and file_name not in ignored_files:
expected_files.add(os.path.join(root, file_name))
detected_files.clear()
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -1706,7 +1706,7 @@ def task_2(arg2):
(1, State.NONE),
(2, State.NONE),
]
ti1 = [i for i in tis if i.map_index == 0][0]
ti1 = next(i for i in tis if i.map_index == 0)
# Now "clear" and "reduce" the length to empty list
dag.clear()
Variable.set(key="arg1", value=[])
Expand Down
8 changes: 4 additions & 4 deletions tests/operators/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -1396,7 +1396,7 @@ def test_short_circuit_with_teardowns(
op1.skip = MagicMock()
dagrun = dag_maker.create_dagrun()
tis = dagrun.get_task_instances()
ti: TaskInstance = [x for x in tis if x.task_id == "op1"][0]
ti: TaskInstance = next(x for x in tis if x.task_id == "op1")
ti._run_raw_task()
expected_tasks = {dag.task_dict[x] for x in expected}
if should_skip:
Expand Down Expand Up @@ -1427,7 +1427,7 @@ def test_short_circuit_with_teardowns_complicated(self, dag_maker, config):
op1.skip = MagicMock()
dagrun = dag_maker.create_dagrun()
tis = dagrun.get_task_instances()
ti: TaskInstance = [x for x in tis if x.task_id == "op1"][0]
ti: TaskInstance = next(x for x in tis if x.task_id == "op1")
ti._run_raw_task()
# we can't use assert_called_with because it's a set and therefore not ordered
actual_skipped = set(op1.skip.call_args.kwargs["tasks"])
Expand All @@ -1454,7 +1454,7 @@ def test_short_circuit_with_teardowns_complicated_2(self, dag_maker):
op1.skip = MagicMock()
dagrun = dag_maker.create_dagrun()
tis = dagrun.get_task_instances()
ti: TaskInstance = [x for x in tis if x.task_id == "op1"][0]
ti: TaskInstance = next(x for x in tis if x.task_id == "op1")
ti._run_raw_task()
# we can't use assert_called_with because it's a set and therefore not ordered
actual_kwargs = op1.skip.call_args.kwargs
Expand Down Expand Up @@ -1489,7 +1489,7 @@ def test_short_circuit_with_teardowns_debug_level(self, dag_maker, level, clear_
op1.skip = MagicMock()
dagrun = dag_maker.create_dagrun()
tis = dagrun.get_task_instances()
ti: TaskInstance = [x for x in tis if x.task_id == "op1"][0]
ti: TaskInstance = next(x for x in tis if x.task_id == "op1")
ti._run_raw_task()
# we can't use assert_called_with because it's a set and therefore not ordered
actual_kwargs = op1.skip.call_args.kwargs
Expand Down
2 changes: 1 addition & 1 deletion tests/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ def sorted_serialized_dag(dag_dict: dict):
items should not matter but assertEqual would fail if the order of
items changes in the dag dictionary
"""
dag_dict["dag"]["tasks"] = sorted(dag_dict["dag"]["tasks"], key=lambda x: sorted(x.keys()))
dag_dict["dag"]["tasks"] = sorted(dag_dict["dag"]["tasks"], key=sorted)
dag_dict["dag"]["_access_control"]["__var"]["test_role"]["__var"] = sorted(
dag_dict["dag"]["_access_control"]["__var"]["test_role"]["__var"]
)
Expand Down
3 changes: 1 addition & 2 deletions tests/system/providers/amazon/aws/example_ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ def get_latest_ami_id():
Owners=["amazon"],
)
# Sort on CreationDate
sorted_images = sorted(images["Images"], key=itemgetter("CreationDate"), reverse=True)
return sorted_images[0]["ImageId"]
return max(images["Images"], key=itemgetter("CreationDate"))["ImageId"]


@task
Expand Down
2 changes: 1 addition & 1 deletion tests/system/providers/amazon/aws/utils/ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def _get_next_available_cidr(vpc_id: str) -> str:
if len({block.prefixlen for block in existing_cidr_blocks}) > 1:
raise ValueError(error_msg_template.format("Subnets do not all use the same CIDR block size."))

last_used_block = sorted(existing_cidr_blocks)[-1]
last_used_block = max(existing_cidr_blocks)
*_, last_reserved_ip = last_used_block
return f"{last_reserved_ip + 1}/{last_used_block.prefixlen}"

Expand Down
2 changes: 1 addition & 1 deletion tests/test_utils/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,6 @@ def get_provider_min_airflow_version(provider_name):

p = ProvidersManager()
deps = p.providers[provider_name].data["dependencies"]
airflow_dep = [x for x in deps if x.startswith("apache-airflow")][0]
airflow_dep = next(x for x in deps if x.startswith("apache-airflow"))
min_airflow_version = tuple(map(int, airflow_dep.split(">=")[1].split(".")))
return min_airflow_version
6 changes: 3 additions & 3 deletions tests/www/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def test_params_search(self):
def test_params_none_and_zero(self):
query_str = utils.get_params(a=0, b=None, c="true")
# The order won't be consistent, but that doesn't affect behaviour of a browser
pairs = list(sorted(query_str.split("&")))
pairs = sorted(query_str.split("&"))
assert ["a=0", "c=true"] == pairs

def test_params_all(self):
Expand Down Expand Up @@ -429,11 +429,11 @@ def test_dag_run_custom_sqla_interface_delete_no_collateral_damage(dag_maker, se
assert len(set(x.run_id for x in dag_runs)) == 3
run_id_for_single_delete = "scheduled__2023-01-01T00:00:00+00:00"
# we have 3 runs with this same run_id
assert len(list(x for x in dag_runs if x.run_id == run_id_for_single_delete)) == 3
assert sum(1 for x in dag_runs if x.run_id == run_id_for_single_delete) == 3
# each is a different dag

# if we delete one, it shouldn't delete the others
one_run = [x for x in dag_runs if x.run_id == run_id_for_single_delete][0]
one_run = next(x for x in dag_runs if x.run_id == run_id_for_single_delete)
assert interface.delete(item=one_run) is True
session.commit()
dag_runs = session.query(DagRun).all()
Expand Down
2 changes: 1 addition & 1 deletion tests/www/views/test_views_acl.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def client_all_dags_dagruns(acl_app, user_all_dags_dagruns):
def test_dag_stats_success(client_all_dags_dagruns):
resp = client_all_dags_dagruns.post("dag_stats", follow_redirects=True)
check_content_in_response("example_bash_operator", resp)
assert set(list(resp.json.items())[0][1][0].keys()) == {"state", "count"}
assert set(next(iter(resp.json.items()))[1][0].keys()) == {"state", "count"}


def test_task_stats_failure(dag_test_client):
Expand Down
6 changes: 3 additions & 3 deletions tests/www/views/test_views_home.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def working_dags(tmpdir):
dag_contents_template = "from airflow import DAG\ndag = DAG('{}', tags=['{}'])"

with create_session() as session:
for dag_id, tag in list(zip(TEST_FILTER_DAG_IDS, TEST_TAGS)):
for dag_id, tag in zip(TEST_FILTER_DAG_IDS, TEST_TAGS):
filename = os.path.join(tmpdir, f"{dag_id}.py")
with open(filename, "w") as f:
f.writelines(dag_contents_template.format(dag_id, tag))
Expand All @@ -169,7 +169,7 @@ def working_dags_with_read_perm(tmpdir):
"access_control={{'role_single_dag':{{'can_read'}}}}) "
)
with create_session() as session:
for dag_id, tag in list(zip(TEST_FILTER_DAG_IDS, TEST_TAGS)):
for dag_id, tag in zip(TEST_FILTER_DAG_IDS, TEST_TAGS):
filename = os.path.join(tmpdir, f"{dag_id}.py")
if dag_id == "filter_test_1":
with open(filename, "w") as f:
Expand All @@ -188,7 +188,7 @@ def working_dags_with_edit_perm(tmpdir):
"access_control={{'role_single_dag':{{'can_edit'}}}}) "
)
with create_session() as session:
for dag_id, tag in list(zip(TEST_FILTER_DAG_IDS, TEST_TAGS)):
for dag_id, tag in zip(TEST_FILTER_DAG_IDS, TEST_TAGS):
filename = os.path.join(tmpdir, f"{dag_id}.py")
if dag_id == "filter_test_1":
with open(filename, "w") as f:
Expand Down