Skip to content

Commit

Permalink
openlineage: don't run task instance listener in executor
Browse files Browse the repository at this point in the history
Signed-off-by: Maciej Obuchowski <[email protected]>
  • Loading branch information
mobuchowski committed Aug 18, 2023
1 parent 744aa60 commit c9de188
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 10 deletions.
20 changes: 11 additions & 9 deletions airflow/providers/openlineage/plugins/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from __future__ import annotations

import logging
from concurrent.futures import Executor, ThreadPoolExecutor
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import TYPE_CHECKING

Expand All @@ -42,8 +42,8 @@ class OpenLineageListener:
"""OpenLineage listener sends events on task instance and dag run starts, completes and failures."""

def __init__(self):
self._executor = None
self.log = logging.getLogger(__name__)
self.executor: Executor = None # type: ignore
self.extractor_manager = ExtractorManager()
self.adapter = OpenLineageAdapter()

Expand Down Expand Up @@ -102,7 +102,7 @@ def on_running():
},
)

self.executor.submit(on_running)
on_running()

@hookimpl
def on_task_instance_success(self, previous_state, task_instance: TaskInstance, session):
Expand Down Expand Up @@ -130,7 +130,7 @@ def on_success():
task=task_metadata,
)

self.executor.submit(on_success)
on_success()

@hookimpl
def on_task_instance_failed(self, previous_state, task_instance: TaskInstance, session):
Expand Down Expand Up @@ -158,12 +158,17 @@ def on_failure():
task=task_metadata,
)

self.executor.submit(on_failure)
on_failure()

@property
def executor(self):
if not self._executor:
self._executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix="openlineage_")
return self._executor

@hookimpl
def on_starting(self, component):
self.log.debug("on_starting: %s", component.__class__.__name__)
self.executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix="openlineage_")

@hookimpl
def before_stopping(self, component):
Expand All @@ -174,9 +179,6 @@ def before_stopping(self, component):

@hookimpl
def on_dag_run_running(self, dag_run: DagRun, msg: str):
if not self.executor:
self.log.error("Executor have not started before `on_dag_run_running`")
return
data_interval_start = dag_run.data_interval_start.isoformat() if dag_run.data_interval_start else None
data_interval_end = dag_run.data_interval_end.isoformat() if dag_run.data_interval_end else None
self.executor.submit(
Expand Down
41 changes: 41 additions & 0 deletions tests/dags/test_dag_xcom_openlineage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
##
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import datetime

from airflow.models import DAG
from airflow.operators.python import PythonOperator

dag = DAG(
dag_id="test_dag_xcom_openlineage",
default_args={"owner": "airflow", "retries": 3, "start_date": datetime.datetime(2022, 1, 1)},
schedule="0 0 * * *",
dagrun_timeout=datetime.timedelta(minutes=60),
)


def push_and_pull(ti, **kwargs):
ti.xcom_push(key="pushed_key", value="asdf")
ti.xcom_pull(key="pushed_key")


task = PythonOperator(task_id="push_and_pull", python_callable=push_and_pull, dag=dag)

if __name__ == "__main__":
dag.cli()
4 changes: 4 additions & 0 deletions tests/listeners/test_listeners.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# under the License.
from __future__ import annotations

import os

import pytest as pytest

from airflow import AirflowException
Expand Down Expand Up @@ -46,6 +48,8 @@
TASK_ID = "test_listener_task"
EXECUTION_DATE = timezone.utcnow()

TEST_DAG_FOLDER = os.environ["AIRFLOW__CORE__DAGS_FOLDER"]


@pytest.fixture(autouse=True)
def clean_listener_manager():
Expand Down
46 changes: 46 additions & 0 deletions tests/listeners/xcom_listener.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from airflow.listeners import hookimpl


class XComListener:
def __init__(self, path: str, task_id: str):
self.path = path
self.task_id = task_id

def write(self, line: str):
with open(self.path, "a") as f:
f.write(line + "\n")

@hookimpl
def on_task_instance_running(self, previous_state, task_instance, session):
task_instance.xcom_push(key="listener", value="listener")
task_instance.xcom_pull(task_ids=task_instance.task_id, key="listener")
self.write("on_task_instance_running")

@hookimpl
def on_task_instance_success(self, previous_state, task_instance, session):
read = task_instance.xcom_pull(task_ids=self.task_id, key="listener")
self.write("on_task_instance_success")
self.write(read)


def clear():
pass
56 changes: 55 additions & 1 deletion tests/task/task_runner/test_standard_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from airflow.utils.platform import getuser
from airflow.utils.state import State
from airflow.utils.timeout import timeout
from tests.listeners import xcom_listener
from tests.listeners.file_write_listener import FileWriteListener
from tests.test_utils.db import clear_db_runs

Expand Down Expand Up @@ -85,10 +86,14 @@ def setup_class(self):
(as the test environment does not have enough context for the normal
way to run) and ensures they reset back to normal on the way out.
"""
get_listener_manager().clear()
clear_db_runs()
yield
clear_db_runs()

@pytest.fixture(autouse=True)
def clean_listener_manager(self):
get_listener_manager().clear()
yield
get_listener_manager().clear()

@patch("airflow.utils.log.file_task_handler.FileTaskHandler._init_file")
Expand Down Expand Up @@ -215,6 +220,55 @@ def test_notifies_about_fail(self):
assert f.readline() == "on_task_instance_failed\n"
assert f.readline() == "before_stopping\n"

def test_ol_does_not_block_xcoms(self):
"""
Test that ensures that pushing and pulling xcoms both in listener and task does not collide
"""

path_listener_writer = "/tmp/test_ol_does_not_block_xcoms"
try:
os.unlink(path_listener_writer)
except OSError:
pass

listener = xcom_listener.XComListener(path_listener_writer, "push_and_pull")
get_listener_manager().add_listener(listener)

dagbag = DagBag(
dag_folder=TEST_DAG_FOLDER,
include_examples=False,
)
dag = dagbag.dags.get("test_dag_xcom_openlineage")
task = dag.get_task("push_and_pull")
dag.create_dagrun(
run_id="test",
data_interval=(DEFAULT_DATE, DEFAULT_DATE),
state=State.RUNNING,
start_date=DEFAULT_DATE,
)

ti = TaskInstance(task=task, run_id="test")
job = Job(dag_id=ti.dag_id)
job_runner = LocalTaskJobRunner(job=job, task_instance=ti, ignore_ti_state=True)
task_runner = StandardTaskRunner(job_runner)
task_runner.start()

# Wait until process makes itself the leader of its own process group
with timeout(seconds=1):
while True:
runner_pgid = os.getpgid(task_runner.process.pid)
if runner_pgid == task_runner.process.pid:
break
time.sleep(0.01)

# Wait till process finishes
assert task_runner.return_code(timeout=10) is not None

with open(path_listener_writer) as f:
assert f.readline() == "on_task_instance_running\n"
assert f.readline() == "on_task_instance_success\n"
assert f.readline() == "listener\n"

@patch("airflow.utils.log.file_task_handler.FileTaskHandler._init_file")
def test_start_and_terminate_run_as_user(self, mock_init):
mock_init.return_value = "/tmp/any"
Expand Down

0 comments on commit c9de188

Please sign in to comment.