Skip to content

Commit

Permalink
Fix timing-based flakey test in TestLocalTaskJob (#8405)
Browse files Browse the repository at this point in the history
This test suffered from timing-based failures: if the "main" process
took even fractionally too long, then the task process would have already
cleaned up its subprocess, so the expected callback in the main/test
process would never be run.

This change ensures that the callback _will always be called_ in the test
process if it is called at all.

GitOrigin-RevId: d06d3165ff7df8ceeb52f8f18154f5f27d83355b
  • Loading branch information
ashb authored and Cloud Composer Team committed Sep 11, 2024
1 parent c535aa9 commit fd8b14b
Showing 1 changed file with 30 additions and 27 deletions.
57 changes: 30 additions & 27 deletions tests/jobs/test_local_task_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@
from airflow.models.dagbag import DagBag
from airflow.models.taskinstance import TaskInstance
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python import PythonOperator
from airflow.utils import timezone
from airflow.utils.net import get_hostname
from airflow.utils.session import create_session
from airflow.utils.state import State
from airflow.utils.timeout import timeout
from tests.test_utils.db import clear_db_runs
from tests.test_utils.mock_executor import MockExecutor

Expand Down Expand Up @@ -302,18 +304,27 @@ def test_mark_failure_on_failure_callback(self):
data = {'called': False}

def check_failure(context):
self.assertEqual(context['dag_run'].dag_id,
'test_mark_failure')
self.assertEqual(context['dag_run'].dag_id, 'test_mark_failure')
data['called'] = True

dag = DAG(dag_id='test_mark_failure',
start_date=DEFAULT_DATE,
default_args={'owner': 'owner1'})

task = DummyOperator(
task_id='test_state_succeeded1',
dag=dag,
on_failure_callback=check_failure)
def task_function(ti):
print("python_callable run in pid %s", os.getpid())
with create_session() as session:
self.assertEqual(State.RUNNING, ti.state)
ti.log.info("Marking TI as failed 'externally'")
ti.state = State.FAILED
session.merge(ti)
session.commit()

time.sleep(60)
# This should not happen -- the state change should be noticed and the task should get killed
data['reached_end_of_sleep'] = True

with DAG(dag_id='test_mark_failure', start_date=DEFAULT_DATE) as dag:
task = PythonOperator(
task_id='test_state_succeeded1',
python_callable=task_function,
on_failure_callback=check_failure)

session = settings.Session()

Expand All @@ -325,28 +336,20 @@ def check_failure(context):
session=session)
ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
ti.refresh_from_db()

job1 = LocalTaskJob(task_instance=ti,
ignore_ti_state=True,
executor=SequentialExecutor())
from airflow.task.task_runner.standard_task_runner import StandardTaskRunner
job1.task_runner = StandardTaskRunner(job1)
process = multiprocessing.Process(target=job1.run)
process.start()
ti.refresh_from_db()
for _ in range(0, 50):
if ti.state == State.RUNNING:
break
time.sleep(0.1)
ti.refresh_from_db()
self.assertEqual(State.RUNNING, ti.state)
ti.state = State.FAILED
session.merge(ti)
session.commit()
with timeout(30):
# This should be _much_ shorter to run.
# If you change this limit, make the timeout in the callable above bigger
job1.run()

job1.heartbeat_callback(session=None)
ti.refresh_from_db()
self.assertEqual(ti.state, State.FAILED)
self.assertTrue(data['called'])
process.join(timeout=10)
self.assertFalse(process.is_alive())
self.assertNotIn('reached_end_of_sleep', data,
'Task should not have been allowed to run to completion')

def test_mark_success_on_success_callback(self):
"""
Expand Down

0 comments on commit fd8b14b

Please sign in to comment.