Skip to content

Commit

Permalink
Refactor: Simplify code in tests (#33293)
Browse files Browse the repository at this point in the history
  • Loading branch information
eumiro authored Aug 12, 2023
1 parent 8279628 commit d2c0bbe
Show file tree
Hide file tree
Showing 14 changed files with 25 additions and 29 deletions.
7 changes: 3 additions & 4 deletions tests/always/test_example_dags.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,15 @@ def example_not_suspended_dags():
for example_dir in example_dirs:
candidates = glob(f"{AIRFLOW_SOURCES_ROOT.as_posix()}/{example_dir}", recursive=True)
for candidate in candidates:
if any(candidate.startswith(s) for s in suspended_providers_folders):
continue
yield candidate
if not candidate.startswith(tuple(suspended_providers_folders)):
yield candidate


def example_dags_except_db_exception():
return [
dag_file
for dag_file in example_not_suspended_dags()
if not any(dag_file.endswith(e) for e in NO_DB_QUERY_EXCEPTION)
if not dag_file.endswith(tuple(NO_DB_QUERY_EXCEPTION))
]


Expand Down
4 changes: 2 additions & 2 deletions tests/always/test_project_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def get_classes_from_file(self, filepath: str):
if not isinstance(current_node, ast.ClassDef):
continue
name = current_node.name
if not any(name.endswith(suffix) for suffix in self.CLASS_SUFFIXES):
if not name.endswith(tuple(self.CLASS_SUFFIXES)):
continue
results[f"{module}.{name}"] = current_node
return results
Expand Down Expand Up @@ -463,6 +463,6 @@ def test_no_illegal_suffixes(self):
)
)

invalid_files = [f for f in files if any(f.endswith(suffix) for suffix in illegal_suffixes)]
invalid_files = [f for f in files if f.endswith(tuple(illegal_suffixes))]

assert [] == invalid_files
4 changes: 1 addition & 3 deletions tests/cli/commands/test_config_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,7 @@ def test_cli_comment_out_everything(self):
)
output = temp_stdout.getvalue()
lines = output.split("\n")
assert all(
line.startswith("#") or line.strip() == "" or line.startswith("[") for line in lines if line
)
assert all(not line.strip() or line.startswith(("#", "[")) for line in lines if line)


class TestCliConfigGetValue:
Expand Down
2 changes: 1 addition & 1 deletion tests/decorators/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ def add_num(number: int, num2: int = 2):
bigger_number.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
ti_add_num = [ti for ti in dr.get_task_instances() if ti.task_id == "add_num"][0]
ti_add_num = next(ti for ti in dr.get_task_instances() if ti.task_id == "add_num")
assert ti_add_num.xcom_pull(key=ret.key) == (test_number + 2) * 2

def test_dag_task(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/jobs/test_scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -3123,7 +3123,7 @@ def test_list_py_file_paths(self):
}
for root, _, files in os.walk(TEST_DAG_FOLDER):
for file_name in files:
if file_name.endswith(".py") or file_name.endswith(".zip"):
if file_name.endswith((".py", ".zip")):
if file_name not in ignored_files:
expected_files.add(f"{root}/{file_name}")
for file_path in list_py_file_paths(TEST_DAG_FOLDER, include_examples=False):
Expand All @@ -3136,7 +3136,7 @@ def test_list_py_file_paths(self):
example_dag_folder = airflow.example_dags.__path__[0]
for root, _, files in os.walk(example_dag_folder):
for file_name in files:
if file_name.endswith(".py") or file_name.endswith(".zip"):
if file_name.endswith((".py", ".zip")):
if file_name not in ["__init__.py"] and file_name not in ignored_files:
expected_files.add(os.path.join(root, file_name))
detected_files.clear()
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -1706,7 +1706,7 @@ def task_2(arg2):
(1, State.NONE),
(2, State.NONE),
]
ti1 = [i for i in tis if i.map_index == 0][0]
ti1 = next(i for i in tis if i.map_index == 0)
# Now "clear" and "reduce" the length to empty list
dag.clear()
Variable.set(key="arg1", value=[])
Expand Down
8 changes: 4 additions & 4 deletions tests/operators/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -1396,7 +1396,7 @@ def test_short_circuit_with_teardowns(
op1.skip = MagicMock()
dagrun = dag_maker.create_dagrun()
tis = dagrun.get_task_instances()
ti: TaskInstance = [x for x in tis if x.task_id == "op1"][0]
ti: TaskInstance = next(x for x in tis if x.task_id == "op1")
ti._run_raw_task()
expected_tasks = {dag.task_dict[x] for x in expected}
if should_skip:
Expand Down Expand Up @@ -1427,7 +1427,7 @@ def test_short_circuit_with_teardowns_complicated(self, dag_maker, config):
op1.skip = MagicMock()
dagrun = dag_maker.create_dagrun()
tis = dagrun.get_task_instances()
ti: TaskInstance = [x for x in tis if x.task_id == "op1"][0]
ti: TaskInstance = next(x for x in tis if x.task_id == "op1")
ti._run_raw_task()
# we can't use assert_called_with because it's a set and therefore not ordered
actual_skipped = set(op1.skip.call_args.kwargs["tasks"])
Expand All @@ -1454,7 +1454,7 @@ def test_short_circuit_with_teardowns_complicated_2(self, dag_maker):
op1.skip = MagicMock()
dagrun = dag_maker.create_dagrun()
tis = dagrun.get_task_instances()
ti: TaskInstance = [x for x in tis if x.task_id == "op1"][0]
ti: TaskInstance = next(x for x in tis if x.task_id == "op1")
ti._run_raw_task()
# we can't use assert_called_with because it's a set and therefore not ordered
actual_kwargs = op1.skip.call_args.kwargs
Expand Down Expand Up @@ -1489,7 +1489,7 @@ def test_short_circuit_with_teardowns_debug_level(self, dag_maker, level, clear_
op1.skip = MagicMock()
dagrun = dag_maker.create_dagrun()
tis = dagrun.get_task_instances()
ti: TaskInstance = [x for x in tis if x.task_id == "op1"][0]
ti: TaskInstance = next(x for x in tis if x.task_id == "op1")
ti._run_raw_task()
# we can't use assert_called_with because it's a set and therefore not ordered
actual_kwargs = op1.skip.call_args.kwargs
Expand Down
2 changes: 1 addition & 1 deletion tests/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ def sorted_serialized_dag(dag_dict: dict):
items should not matter but assertEqual would fail if the order of
items changes in the dag dictionary
"""
dag_dict["dag"]["tasks"] = sorted(dag_dict["dag"]["tasks"], key=lambda x: sorted(x.keys()))
dag_dict["dag"]["tasks"] = sorted(dag_dict["dag"]["tasks"], key=sorted)
dag_dict["dag"]["_access_control"]["__var"]["test_role"]["__var"] = sorted(
dag_dict["dag"]["_access_control"]["__var"]["test_role"]["__var"]
)
Expand Down
3 changes: 1 addition & 2 deletions tests/system/providers/amazon/aws/example_ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ def get_latest_ami_id():
Owners=["amazon"],
)
# Sort on CreationDate
sorted_images = sorted(images["Images"], key=itemgetter("CreationDate"), reverse=True)
return sorted_images[0]["ImageId"]
return max(images["Images"], key=itemgetter("CreationDate"))["ImageId"]


@task
Expand Down
2 changes: 1 addition & 1 deletion tests/system/providers/amazon/aws/utils/ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def _get_next_available_cidr(vpc_id: str) -> str:
if len({block.prefixlen for block in existing_cidr_blocks}) > 1:
raise ValueError(error_msg_template.format("Subnets do not all use the same CIDR block size."))

last_used_block = sorted(existing_cidr_blocks)[-1]
last_used_block = max(existing_cidr_blocks)
*_, last_reserved_ip = last_used_block
return f"{last_reserved_ip + 1}/{last_used_block.prefixlen}"

Expand Down
2 changes: 1 addition & 1 deletion tests/test_utils/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,6 @@ def get_provider_min_airflow_version(provider_name):

p = ProvidersManager()
deps = p.providers[provider_name].data["dependencies"]
airflow_dep = [x for x in deps if x.startswith("apache-airflow")][0]
airflow_dep = next(x for x in deps if x.startswith("apache-airflow"))
min_airflow_version = tuple(map(int, airflow_dep.split(">=")[1].split(".")))
return min_airflow_version
6 changes: 3 additions & 3 deletions tests/www/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def test_params_search(self):
def test_params_none_and_zero(self):
query_str = utils.get_params(a=0, b=None, c="true")
# The order won't be consistent, but that doesn't affect behaviour of a browser
pairs = list(sorted(query_str.split("&")))
pairs = sorted(query_str.split("&"))
assert ["a=0", "c=true"] == pairs

def test_params_all(self):
Expand Down Expand Up @@ -429,11 +429,11 @@ def test_dag_run_custom_sqla_interface_delete_no_collateral_damage(dag_maker, se
assert len(set(x.run_id for x in dag_runs)) == 3
run_id_for_single_delete = "scheduled__2023-01-01T00:00:00+00:00"
# we have 3 runs with this same run_id
assert len(list(x for x in dag_runs if x.run_id == run_id_for_single_delete)) == 3
assert sum(1 for x in dag_runs if x.run_id == run_id_for_single_delete) == 3
# each is a different dag

# if we delete one, it shouldn't delete the others
one_run = [x for x in dag_runs if x.run_id == run_id_for_single_delete][0]
one_run = next(x for x in dag_runs if x.run_id == run_id_for_single_delete)
assert interface.delete(item=one_run) is True
session.commit()
dag_runs = session.query(DagRun).all()
Expand Down
2 changes: 1 addition & 1 deletion tests/www/views/test_views_acl.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def client_all_dags_dagruns(acl_app, user_all_dags_dagruns):
def test_dag_stats_success(client_all_dags_dagruns):
resp = client_all_dags_dagruns.post("dag_stats", follow_redirects=True)
check_content_in_response("example_bash_operator", resp)
assert set(list(resp.json.items())[0][1][0].keys()) == {"state", "count"}
assert set(next(iter(resp.json.items()))[1][0].keys()) == {"state", "count"}


def test_task_stats_failure(dag_test_client):
Expand Down
6 changes: 3 additions & 3 deletions tests/www/views/test_views_home.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def working_dags(tmpdir):
dag_contents_template = "from airflow import DAG\ndag = DAG('{}', tags=['{}'])"

with create_session() as session:
for dag_id, tag in list(zip(TEST_FILTER_DAG_IDS, TEST_TAGS)):
for dag_id, tag in zip(TEST_FILTER_DAG_IDS, TEST_TAGS):
filename = os.path.join(tmpdir, f"{dag_id}.py")
with open(filename, "w") as f:
f.writelines(dag_contents_template.format(dag_id, tag))
Expand All @@ -169,7 +169,7 @@ def working_dags_with_read_perm(tmpdir):
"access_control={{'role_single_dag':{{'can_read'}}}}) "
)
with create_session() as session:
for dag_id, tag in list(zip(TEST_FILTER_DAG_IDS, TEST_TAGS)):
for dag_id, tag in zip(TEST_FILTER_DAG_IDS, TEST_TAGS):
filename = os.path.join(tmpdir, f"{dag_id}.py")
if dag_id == "filter_test_1":
with open(filename, "w") as f:
Expand All @@ -188,7 +188,7 @@ def working_dags_with_edit_perm(tmpdir):
"access_control={{'role_single_dag':{{'can_edit'}}}}) "
)
with create_session() as session:
for dag_id, tag in list(zip(TEST_FILTER_DAG_IDS, TEST_TAGS)):
for dag_id, tag in zip(TEST_FILTER_DAG_IDS, TEST_TAGS):
filename = os.path.join(tmpdir, f"{dag_id}.py")
if dag_id == "filter_test_1":
with open(filename, "w") as f:
Expand Down

0 comments on commit d2c0bbe

Please sign in to comment.