Skip to content

Commit

Permalink
Move Pydantic class serialization under AIP-44 feature flag (#30560)
Browse files Browse the repository at this point in the history
The Pydantic representation of the ORM models is only used
in AIP-44 in-progress feature, and we are moving to a new
seialization implementation (more modular) in a near feature
so in order to not unecessarily extend features in old
serialization, but allow to test AIP-44, we are moving the
use_pydantic_models parameter and it's implementation under
_ENABLE_AIP_44 feature flag, so that it is not used accidentally.

We will eventually remove it and add Pydantic serialization to
the new serialization implementation.
  • Loading branch information
potiuk authored Apr 10, 2023
1 parent 18ec7f2 commit e8da514
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 21 deletions.
60 changes: 39 additions & 21 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
from airflow.serialization.pydantic.dag_run import DagRunPydantic
from airflow.serialization.pydantic.dataset import DatasetPydantic
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
from airflow.settings import DAGS_FOLDER, json
from airflow.settings import _ENABLE_AIP_44, DAGS_FOLDER, json
from airflow.timetables.base import Timetable
from airflow.utils.code_utils import get_python_source
from airflow.utils.docs import get_docs_url
Expand Down Expand Up @@ -404,6 +404,11 @@ def serialize(
:meta private:
"""
if use_pydantic_models and not _ENABLE_AIP_44:
raise RuntimeError(
"Setting use_pydantic_models = True requires AIP-44 (in progress) feature flag to be true. "
"This parameter will be removed eventually when new serialization is used by AIP-44"
)
if cls._is_primitive(var):
# enum.IntEnum is an int instance, it causes json dumps error so we use its value.
if isinstance(var, enum.Enum):
Expand Down Expand Up @@ -473,19 +478,26 @@ def serialize(
cls.serialize(var.__dict__, strict=strict, use_pydantic_models=use_pydantic_models),
type_=DAT.SIMPLE_TASK_INSTANCE,
)
elif use_pydantic_models and isinstance(var, BaseJob):
return cls._encode(BaseJobPydantic.from_orm(var).dict(), type_=DAT.BASE_JOB)
elif use_pydantic_models and isinstance(var, TaskInstance):
return cls._encode(TaskInstancePydantic.from_orm(var).dict(), type_=DAT.TASK_INSTANCE)
elif use_pydantic_models and isinstance(var, DagRun):
return cls._encode(DagRunPydantic.from_orm(var).dict(), type_=DAT.DAG_RUN)
elif use_pydantic_models and isinstance(var, Dataset):
return cls._encode(DatasetPydantic.from_orm(var).dict(), type_=DAT.DATA_SET)
elif use_pydantic_models and _ENABLE_AIP_44:
if isinstance(var, BaseJob):
return cls._encode(BaseJobPydantic.from_orm(var).dict(), type_=DAT.BASE_JOB)
elif isinstance(var, TaskInstance):
return cls._encode(TaskInstancePydantic.from_orm(var).dict(), type_=DAT.TASK_INSTANCE)
elif isinstance(var, DagRun):
return cls._encode(DagRunPydantic.from_orm(var).dict(), type_=DAT.DAG_RUN)
elif isinstance(var, Dataset):
return cls._encode(DatasetPydantic.from_orm(var).dict(), type_=DAT.DATA_SET)
else:
return cls.default_serialization(strict, var)
else:
log.debug("Cast type %s to str in serialization.", type(var))
if strict:
raise SerializationError("Encountered unexpected type")
return str(var)
return cls.default_serialization(strict, var)

@classmethod
def default_serialization(cls, strict, var) -> str:
log.debug("Cast type %s to str in serialization.", type(var))
if strict:
raise SerializationError("Encountered unexpected type")
return str(var)

@classmethod
def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any:
Expand All @@ -494,6 +506,11 @@ def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any:
:meta private:
"""
# JSON primitives (except for dict) are not encoded.
if use_pydantic_models and not _ENABLE_AIP_44:
raise RuntimeError(
"Setting use_pydantic_models = True requires AIP-44 (in progress) feature flag to be true. "
"This parameter will be removed eventually when new serialization is used by AIP-44"
)
if cls._is_primitive(encoded_var):
return encoded_var
elif isinstance(encoded_var, list):
Expand Down Expand Up @@ -535,14 +552,15 @@ def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any:
return Dataset(**var)
elif type_ == DAT.SIMPLE_TASK_INSTANCE:
return SimpleTaskInstance(**cls.deserialize(var))
elif use_pydantic_models and type_ == DAT.BASE_JOB:
return BaseJobPydantic.parse_obj(var)
elif use_pydantic_models and type_ == DAT.TASK_INSTANCE:
return TaskInstancePydantic.parse_obj(var)
elif use_pydantic_models and type_ == DAT.DAG_RUN:
return DagRunPydantic.parse_obj(var)
elif use_pydantic_models and type_ == DAT.DATA_SET:
return DatasetPydantic.parse_obj(var)
elif use_pydantic_models and _ENABLE_AIP_44:
if type_ == DAT.BASE_JOB:
return BaseJobPydantic.parse_obj(var)
elif type_ == DAT.TASK_INSTANCE:
return TaskInstancePydantic.parse_obj(var)
elif type_ == DAT.DAG_RUN:
return DagRunPydantic.parse_obj(var)
elif type_ == DAT.DATA_SET:
return DatasetPydantic.parse_obj(var)
else:
raise TypeError(f"Invalid type {type_!s} in deserialization.")

Expand Down
2 changes: 2 additions & 0 deletions tests/serialization/test_serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from airflow.models.taskinstance import TaskInstance
from airflow.operators.empty import EmptyOperator
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
from airflow.settings import _ENABLE_AIP_44
from airflow.utils.state import State
from tests import REPO_ROOT

Expand Down Expand Up @@ -82,6 +83,7 @@ class Test:
BaseSerialization.serialize(obj, strict=True) # now raises


@pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled")
def test_use_pydantic_models():
"""If use_pydantic_models=True the TaskInstance object should be serialized to TaskInstancePydantic."""

Expand Down

0 comments on commit e8da514

Please sign in to comment.