From c7a34df89c94609a4fc45734e9a48b126985d71e Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Fri, 6 Sep 2024 14:41:30 +0800
Subject: [PATCH] Rewrite how DAG to dataset / dataset alias are stored
 (#41987) (#42055)

---
 airflow/models/dag.py | 88 ++++++++++++++++++++++++-------------------
 1 file changed, 50 insertions(+), 38 deletions(-)

diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index f848346780857..58213efeecff3 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -3236,8 +3236,6 @@ def bulk_write_to_db(
         if not dags:
             return
 
-        from airflow.models.dataset import DagScheduleDatasetAliasReference
-
         log.info("Sync %s DAGs", len(dags))
         dag_by_ids = {dag.dag_id: dag for dag in dags}
 
@@ -3344,18 +3342,19 @@ def bulk_write_to_db(
 
         from airflow.datasets import Dataset
         from airflow.models.dataset import (
+            DagScheduleDatasetAliasReference,
             DagScheduleDatasetReference,
             DatasetModel,
             TaskOutletDatasetReference,
         )
 
-        dag_references: dict[str, set[Dataset | DatasetAlias]] = defaultdict(set)
+        dag_references: dict[str, set[tuple[Literal["dataset", "dataset-alias"], str]]] = defaultdict(set)
         outlet_references = defaultdict(set)
         # We can't use a set here as we want to preserve order
-        outlet_datasets: dict[DatasetModel, None] = {}
-        input_datasets: dict[DatasetModel, None] = {}
+        outlet_dataset_models: dict[DatasetModel, None] = {}
+        input_dataset_models: dict[DatasetModel, None] = {}
         outlet_dataset_alias_models: set[DatasetAliasModel] = set()
-        input_dataset_aliases: set[DatasetAliasModel] = set()
+        input_dataset_alias_models: set[DatasetAliasModel] = set()
 
         # here we go through dags and tasks to check for dataset references
         # if there are now None and previously there were some, we delete them
@@ -3371,12 +3370,12 @@ def bulk_write_to_db(
                         curr_orm_dag.schedule_dataset_alias_references = []
             else:
                 for _, dataset in dataset_condition.iter_datasets():
-                    dag_references[dag.dag_id].add(Dataset(uri=dataset.uri))
-                    input_datasets[DatasetModel.from_public(dataset)] = None
+                    dag_references[dag.dag_id].add(("dataset", dataset.uri))
+                    input_dataset_models[DatasetModel.from_public(dataset)] = None
 
                 for dataset_alias in dataset_condition.iter_dataset_aliases():
-                    dag_references[dag.dag_id].add(dataset_alias)
-                    input_dataset_aliases.add(DatasetAliasModel.from_public(dataset_alias))
+                    dag_references[dag.dag_id].add(("dataset-alias", dataset_alias.name))
+                    input_dataset_alias_models.add(DatasetAliasModel.from_public(dataset_alias))
 
             curr_outlet_references = curr_orm_dag and curr_orm_dag.task_outlet_dataset_references
             for task in dag.tasks:
@@ -3399,63 +3398,70 @@ def bulk_write_to_db(
                         curr_outlet_references.remove(ref)
 
                 for d in dataset_outlets:
+                    outlet_dataset_models[DatasetModel.from_public(d)] = None
                     outlet_references[(task.dag_id, task.task_id)].add(d.uri)
-                    outlet_datasets[DatasetModel.from_public(d)] = None
 
                 for d_a in dataset_alias_outlets:
                     outlet_dataset_alias_models.add(DatasetAliasModel.from_public(d_a))
 
-        all_datasets = outlet_datasets
-        all_datasets.update(input_datasets)
+        all_dataset_models = outlet_dataset_models
+        all_dataset_models.update(input_dataset_models)
 
         # store datasets
-        stored_datasets: dict[str, DatasetModel] = {}
-        new_datasets: list[DatasetModel] = []
-        for dataset in all_datasets:
-            stored_dataset = session.scalar(
+        stored_dataset_models: dict[str, DatasetModel] = {}
+        new_dataset_models: list[DatasetModel] = []
+        for dataset in all_dataset_models:
+            stored_dataset_model = session.scalar(
                 select(DatasetModel).where(DatasetModel.uri == dataset.uri).limit(1)
             )
-            if stored_dataset:
+            if stored_dataset_model:
                 # Some datasets may have been previously unreferenced, and therefore orphaned by the
                 # scheduler. But if we're here, then we have found that dataset again in our DAGs, which
                 # means that it is no longer an orphan, so set is_orphaned to False.
-                stored_dataset.is_orphaned = expression.false()
-                stored_datasets[stored_dataset.uri] = stored_dataset
+                stored_dataset_model.is_orphaned = expression.false()
+                stored_dataset_models[stored_dataset_model.uri] = stored_dataset_model
             else:
-                new_datasets.append(dataset)
-        dataset_manager.create_datasets(dataset_models=new_datasets, session=session)
-        stored_datasets.update({dataset.uri: dataset for dataset in new_datasets})
+                new_dataset_models.append(dataset)
+        dataset_manager.create_datasets(dataset_models=new_dataset_models, session=session)
+        stored_dataset_models.update(
+            {dataset_model.uri: dataset_model for dataset_model in new_dataset_models}
+        )
 
-        del new_datasets
-        del all_datasets
+        del new_dataset_models
+        del all_dataset_models
 
         # store dataset aliases
-        all_datasets_alias_models = input_dataset_aliases | outlet_dataset_alias_models
-        stored_dataset_aliases: dict[str, DatasetAliasModel] = {}
+        all_datasets_alias_models = input_dataset_alias_models | outlet_dataset_alias_models
+        stored_dataset_alias_models: dict[str, DatasetAliasModel] = {}
         new_dataset_alias_models: set[DatasetAliasModel] = set()
         if all_datasets_alias_models:
-            all_dataset_alias_names = {dataset_alias.name for dataset_alias in all_datasets_alias_models}
+            all_dataset_alias_names = {
+                dataset_alias_model.name for dataset_alias_model in all_datasets_alias_models
+            }
 
-            stored_dataset_aliases = {
+            stored_dataset_alias_models = {
                 dsa_m.name: dsa_m
                 for dsa_m in session.scalars(
                     select(DatasetAliasModel).where(DatasetAliasModel.name.in_(all_dataset_alias_names))
                 ).fetchall()
             }
 
-            if stored_dataset_aliases:
+            if stored_dataset_alias_models:
                 new_dataset_alias_models = {
                     dataset_alias_model
                     for dataset_alias_model in all_datasets_alias_models
-                    if dataset_alias_model.name not in stored_dataset_aliases.keys()
+                    if dataset_alias_model.name not in stored_dataset_alias_models.keys()
                 }
             else:
                 new_dataset_alias_models = all_datasets_alias_models
             session.add_all(new_dataset_alias_models)
         session.flush()
 
-        stored_dataset_aliases.update(
-            {dataset_alias.name: dataset_alias for dataset_alias in new_dataset_alias_models}
+        stored_dataset_alias_models.update(
+            {
+                dataset_alias_model.name: dataset_alias_model
+                for dataset_alias_model in new_dataset_alias_models
+            }
         )
 
         del new_dataset_alias_models
@@ -3464,14 +3470,18 @@ def bulk_write_to_db(
         # reconcile dag-schedule-on-dataset and dag-schedule-on-dataset-alias references
         for dag_id, base_dataset_list in dag_references.items():
             dag_refs_needed = {
-                DagScheduleDatasetReference(dataset_id=stored_datasets[base_dataset.uri].id, dag_id=dag_id)
-                if isinstance(base_dataset, Dataset)
+                DagScheduleDatasetReference(
+                    dataset_id=stored_dataset_models[base_dataset_identifier].id, dag_id=dag_id
+                )
+                if base_dataset_type == "dataset"
                 else DagScheduleDatasetAliasReference(
-                    alias_id=stored_dataset_aliases[base_dataset.name].id, dag_id=dag_id
+                    alias_id=stored_dataset_alias_models[base_dataset_identifier].id, dag_id=dag_id
                 )
-                for base_dataset in base_dataset_list
+                for base_dataset_type, base_dataset_identifier in base_dataset_list
             }
 
+            # if isinstance(base_dataset, Dataset)
+
             dag_refs_stored = (
                 set(existing_dags.get(dag_id).schedule_dataset_references)  # type: ignore
                 | set(existing_dags.get(dag_id).schedule_dataset_alias_references)  # type: ignore
@@ -3491,7 +3501,9 @@ def bulk_write_to_db(
         # reconcile task-outlet-dataset references
         for (dag_id, task_id), uri_list in outlet_references.items():
             task_refs_needed = {
-                TaskOutletDatasetReference(dataset_id=stored_datasets[uri].id, dag_id=dag_id, task_id=task_id)
+                TaskOutletDatasetReference(
+                    dataset_id=stored_dataset_models[uri].id, dag_id=dag_id, task_id=task_id
+                )
                 for uri in uri_list
             }
             task_refs_stored = existing_task_outlet_refs_dict[(dag_id, task_id)]
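
The patch boils down to one bookkeeping change: instead of keeping a per-DAG set of Dataset / DatasetAlias objects and dispatching on isinstance() during reconciliation, each DAG's schedule inputs are recorded as plain ("dataset", uri) / ("dataset-alias", name) tuples, and the tag decides which reference row to build. Below is a minimal standalone sketch of that pattern, assuming simplified stand-in dataclasses rather than the real ORM models from airflow.models.dataset; the dag_id, URI, alias name, and id values are made up for illustration.

from collections import defaultdict
from dataclasses import dataclass
from typing import Literal


# Stand-ins for the ORM reference rows; the real models live in airflow.models.dataset.
@dataclass(frozen=True)
class DagScheduleDatasetReference:
    dataset_id: int
    dag_id: str


@dataclass(frozen=True)
class DagScheduleDatasetAliasReference:
    alias_id: int
    dag_id: str


# Collection phase: record hashable (kind, identifier) tuples instead of
# Dataset / DatasetAlias objects, so no isinstance() dispatch is needed later.
dag_references: dict[str, set[tuple[Literal["dataset", "dataset-alias"], str]]] = defaultdict(set)
dag_references["example_dag"].add(("dataset", "s3://bucket/example.csv"))
dag_references["example_dag"].add(("dataset-alias", "example-alias"))

# Pretend these lookups were produced by the "store datasets" / "store dataset aliases"
# steps earlier in the patch (uri -> primary key, name -> primary key).
stored_dataset_ids = {"s3://bucket/example.csv": 1}
stored_dataset_alias_ids = {"example-alias": 7}

# Reconciliation phase: the tag picks which reference row to build, mirroring the
# `if base_dataset_type == "dataset"` branch in the patch.
for dag_id, base_dataset_list in dag_references.items():
    refs_needed = {
        DagScheduleDatasetReference(dataset_id=stored_dataset_ids[identifier], dag_id=dag_id)
        if kind == "dataset"
        else DagScheduleDatasetAliasReference(alias_id=stored_dataset_alias_ids[identifier], dag_id=dag_id)
        for kind, identifier in base_dataset_list
    }
    print(dag_id, sorted(refs_needed, key=repr))

Running the sketch prints one dataset reference and one alias reference for "example_dag", which is the same split the real code persists via the DagScheduleDatasetReference and DagScheduleDatasetAliasReference tables.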