diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py index c8584a9f4a567..5efca124e8852 100644 --- a/airflow/models/abstractoperator.py +++ b/airflow/models/abstractoperator.py @@ -154,37 +154,38 @@ def get_direct_relative_ids(self, upstream: bool = False) -> set[str]: return self.upstream_task_ids return self.downstream_task_ids - def get_flat_relative_ids( - self, - upstream: bool = False, - found_descendants: set[str] | None = None, - ) -> set[str]: - """Get a flat set of relative IDs, upstream or downstream.""" + def get_flat_relative_ids(self, *, upstream: bool = False) -> set[str]: + """ + Get a flat set of relative IDs, upstream or downstream. + + Will recurse each relative found in the direction specified. + + :param upstream: Whether to look for upstream or downstream relatives. + """ dag = self.get_dag() if not dag: return set() - if found_descendants is None: - found_descendants = set() + relatives: set[str] = set() task_ids_to_trace = self.get_direct_relative_ids(upstream) while task_ids_to_trace: task_ids_to_trace_next: set[str] = set() for task_id in task_ids_to_trace: - if task_id in found_descendants: + if task_id in relatives: continue task_ids_to_trace_next.update(dag.task_dict[task_id].get_direct_relative_ids(upstream)) - found_descendants.add(task_id) + relatives.add(task_id) task_ids_to_trace = task_ids_to_trace_next - return found_descendants + return relatives def get_flat_relatives(self, upstream: bool = False) -> Collection[Operator]: """Get a flat list of relatives, either upstream or downstream.""" dag = self.get_dag() if not dag: return set() - return [dag.task_dict[task_id] for task_id in self.get_flat_relative_ids(upstream)] + return [dag.task_dict[task_id] for task_id in self.get_flat_relative_ids(upstream=upstream)] def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]: """Return mapped nodes that are direct dependencies of the current task.