Skip to content

Commit

Permalink
Refactor dynamic task mapping implementation (#313)
Browse files Browse the repository at this point in the history
Currently DAG Factory dynamic task mapping implementation relies in
internal Airflow implementation details (that it uses XComArg). If
Airflow changes the implementation, DAG factory will stop working.

With this PR, we're refactoring to use the `task.output` parameter that
is closer to the Airflow public API.

This was a suggestion by @ashb while we were discussing the overall DAG
Factory dynamic task implementation.
  • Loading branch information
tatiana authored Dec 6, 2024
1 parent a6bf015 commit 80b885e
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
5 changes: 2 additions & 3 deletions dagfactory/dagbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@
else:
MappedOperator = None

from airflow.models.xcom_arg import XComArg

if version.parse(AIRFLOW_VERSION) >= version.parse("2.4.0"):
from airflow.datasets import Dataset
Expand Down Expand Up @@ -682,11 +681,11 @@ def replace_expand_values(task_conf: Dict, tasks_dict: Dict[str, BaseOperator]):
if ".output" in expand_value:
task_id = expand_value.split(".output")[0]
if task_id in tasks_dict:
task_conf["expand"][expand_key] = XComArg(tasks_dict[task_id])
task_conf["expand"][expand_key] = tasks_dict[task_id].output
elif "XcomArg" in expand_value:
task_id = re.findall(r"\(+(.*?)\)", expand_value)[0]
if task_id in tasks_dict:
task_conf["expand"][expand_key] = XComArg(tasks_dict[task_id])
task_conf["expand"][expand_key] = tasks_dict[task_id].output
return task_conf

# pylint: disable=too-many-locals
Expand Down
9 changes: 7 additions & 2 deletions tests/test_dagbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,7 +989,6 @@ def test_dynamic_task_mapping():
assert isinstance(actual, MappedOperator)


@patch("dagfactory.dagbuilder.PythonOperator", new=MockPythonOperator)
def test_replace_expand_string_with_xcom():
td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG_DYNAMIC_TASK_MAPPING, DEFAULT_CONFIG)
if version.parse(AIRFLOW_VERSION) < version.parse("2.3.0"):
Expand All @@ -1000,7 +999,13 @@ def test_replace_expand_string_with_xcom():

task_conf_output = {"expand": {"key_1": "task_1.output"}}
task_conf_xcomarg = {"expand": {"key_1": "XcomArg(task_1)"}}
tasks_dict = {"task_1": MockPythonOperator()}

task1 = PythonOperator(
task_id="task1",
python_callable=lambda: print("hello"),
)

tasks_dict = {"task_1": task1}
updated_task_conf_output = dagbuilder.DagBuilder.replace_expand_values(task_conf_output, tasks_dict)
updated_task_conf_xcomarg = dagbuilder.DagBuilder.replace_expand_values(task_conf_xcomarg, tasks_dict)
assert updated_task_conf_output["expand"]["key_1"] == XComArg(tasks_dict["task_1"])
Expand Down

0 comments on commit 80b885e

Please sign in to comment.