Skip to content

Commit

Permalink
Independently update source code changes in DagCode (apache#44189)
Browse files Browse the repository at this point in the history
* Independently update source code changes in DagCode

When the source code changes without any structural change that would
trigger a new version, we should still update the DagCode's source code.

To do that, I removed the fileloc_hash column, which is no longer necessary
because the code is now read from the DB by dag_id. I also added the source_code_hash
column, which is used to detect code changes and update the stored source code.

* fixup! Independently update source code changes in DagCode

* fixup! fixup! Independently update source code changes in DagCode

* fixup! fixup! fixup! Independently update source code changes in DagCode

* fixup! fixup! fixup! fixup! Independently update source code changes in DagCode
  • Loading branch information
ephraimbuddy authored and Lefteris Gilmaz committed Jan 5, 2025
1 parent fa27dff commit 6feae13
Show file tree
Hide file tree
Showing 8 changed files with 1,015 additions and 964 deletions.
10 changes: 6 additions & 4 deletions airflow/migrations/versions/0047_3_0_0_add_dag_versioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,9 @@ def upgrade():
ondelete="CASCADE",
)
batch_op.create_unique_constraint("dag_code_dag_version_id_uq", ["dag_version_id"])
batch_op.drop_column("last_updated")
batch_op.add_column(sa.Column("created_at", UtcDateTime(), nullable=False, default=timezone.utcnow))
batch_op.add_column(sa.Column("source_code_hash", sa.String(length=32), nullable=False))
batch_op.drop_column("fileloc_hash")
batch_op.add_column(sa.Column("dag_id", sa.String(length=250), nullable=False))

with op.batch_alter_table(
"serialized_dag", recreate="always", naming_convention=naming_convention
Expand Down Expand Up @@ -142,9 +143,10 @@ def downgrade():
batch_op.drop_column("id")
batch_op.drop_constraint(batch_op.f("dag_code_dag_version_id_fkey"), type_="foreignkey")
batch_op.drop_column("dag_version_id")
batch_op.add_column(sa.Column("fileloc_hash", sa.BigInteger, nullable=False))
batch_op.create_primary_key("dag_code_pkey", ["fileloc_hash"])
batch_op.drop_column("created_at")
batch_op.add_column(sa.Column("last_updated", UtcDateTime(), nullable=False))
batch_op.drop_column("source_code_hash")
batch_op.drop_column("dag_id")

with op.batch_alter_table("serialized_dag", schema=None, naming_convention=naming_convention) as batch_op:
batch_op.drop_column("id")
Expand Down
4 changes: 4 additions & 0 deletions airflow/models/dagbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
)
from airflow.listeners.listener import get_listener_manager
from airflow.models.base import Base
from airflow.models.dagcode import DagCode
from airflow.stats import Stats
from airflow.utils import timezone
from airflow.utils.dag_cycle_tester import check_cycle
Expand Down Expand Up @@ -626,6 +627,9 @@ def _serialize_dag_capturing_errors(dag, session, processor_subdir):
)
if dag_was_updated:
DagBag._sync_perm_for_dag(dag, session=session)
else:
# Check and update DagCode
DagCode.update_source_code(dag)
return []
except OperationalError:
raise
Expand Down
114 changes: 66 additions & 48 deletions airflow/models/dagcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,27 +17,29 @@
from __future__ import annotations

import logging
import struct
from typing import TYPE_CHECKING

import uuid6
from sqlalchemy import BigInteger, Column, ForeignKey, String, Text, select
from sqlalchemy import Column, ForeignKey, String, Text, select
from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlalchemy.orm import relationship
from sqlalchemy.sql.expression import literal
from sqlalchemy_utils import UUIDType

from airflow.configuration import conf
from airflow.exceptions import DagCodeNotFound
from airflow.models.base import Base
from airflow.models.base import ID_LEN, Base
from airflow.utils import timezone
from airflow.utils.file import open_maybe_zipped
from airflow.utils.hashlib_wrapper import md5
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime

if TYPE_CHECKING:
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select

from airflow.models.dag import DAG
from airflow.models.dag_version import DagVersion

log = logging.getLogger(__name__)
Expand All @@ -54,11 +56,12 @@ class DagCode(Base):

__tablename__ = "dag_code"
id = Column(UUIDType(binary=False), primary_key=True, default=uuid6.uuid7)
fileloc_hash = Column(BigInteger, nullable=False)
dag_id = Column(String(ID_LEN), nullable=False)
fileloc = Column(String(2000), nullable=False)
# The max length of fileloc exceeds the limit of indexing.
created_at = Column(UtcDateTime, nullable=False, default=timezone.utcnow)
last_updated = Column(UtcDateTime, nullable=False, default=timezone.utcnow, onupdate=timezone.utcnow)
source_code = Column(Text().with_variant(MEDIUMTEXT(), "mysql"), nullable=False)
source_code_hash = Column(String(32), nullable=False)
dag_version_id = Column(
UUIDType(binary=False), ForeignKey("dag_version.id", ondelete="CASCADE"), nullable=False, unique=True
)
Expand All @@ -67,9 +70,9 @@ class DagCode(Base):
def __init__(self, dag_version, full_filepath: str, source_code: str | None = None):
self.dag_version = dag_version
self.fileloc = full_filepath
self.fileloc_hash = DagCode.dag_fileloc_hash(self.fileloc)
self.last_updated = timezone.utcnow()
self.source_code = source_code or DagCode.code(self.fileloc)
self.source_code = source_code or DagCode.code(self.dag_version.dag_id)
self.source_code_hash = self.dag_source_hash(self.source_code)
self.dag_id = dag_version.dag_id

@classmethod
@provide_session
Expand All @@ -81,50 +84,37 @@ def write_code(cls, dag_version: DagVersion, fileloc: str, session: Session = NE
:param session: ORM Session
"""
log.debug("Writing DAG file %s into DagCode table", fileloc)
dag_code = DagCode(dag_version, fileloc, cls._get_code_from_file(fileloc))
dag_code = DagCode(dag_version, fileloc, cls.get_code_from_file(fileloc))
session.add(dag_code)
log.debug("DAG file %s written into DagCode table", fileloc)
return dag_code

@classmethod
@provide_session
def has_dag(cls, fileloc: str, session: Session = NEW_SESSION) -> bool:
def has_dag(cls, dag_id: str, session: Session = NEW_SESSION) -> bool:
"""
Check a file exist in dag_code table.
Check a dag exists in dag code table.
:param fileloc: the file to check
:param dag_id: the dag_id of the DAG
:param session: ORM Session
"""
fileloc_hash = cls.dag_fileloc_hash(fileloc)
return (
session.scalars(
select(literal(True)).where(cls.fileloc_hash == fileloc_hash).limit(1)
).one_or_none()
session.scalars(select(literal(True)).where(cls.dag_id == dag_id).limit(1)).one_or_none()
is not None
)

@classmethod
def get_code_by_fileloc(cls, fileloc: str) -> str:
"""
Return source code for a given fileloc.
:param fileloc: file path of a DAG
:return: source code as string
"""
return cls.code(fileloc)

@classmethod
@provide_session
def code(cls, fileloc, session: Session = NEW_SESSION) -> str:
def code(cls, dag_id, session: Session = NEW_SESSION) -> str:
"""
Return source code for this DagCode object.
:return: source code as string
"""
return cls._get_code_from_db(fileloc, session)
return cls._get_code_from_db(dag_id, session)

@staticmethod
def _get_code_from_file(fileloc):
def get_code_from_file(fileloc):
try:
with open_maybe_zipped(fileloc, "r") as f:
code = f.read()
Expand All @@ -137,12 +127,9 @@ def _get_code_from_file(fileloc):

@classmethod
@provide_session
def _get_code_from_db(cls, fileloc, session: Session = NEW_SESSION) -> str:
def _get_code_from_db(cls, dag_id, session: Session = NEW_SESSION) -> str:
dag_code = session.scalar(
select(cls)
.where(cls.fileloc_hash == cls.dag_fileloc_hash(fileloc))
.order_by(cls.created_at.desc())
.limit(1)
select(cls).where(cls.dag_id == dag_id).order_by(cls.last_updated.desc()).limit(1)
)
if not dag_code:
raise DagCodeNotFound()
Expand All @@ -151,21 +138,52 @@ def _get_code_from_db(cls, fileloc, session: Session = NEW_SESSION) -> str:
return code

@staticmethod
def dag_fileloc_hash(full_filepath: str) -> int:
def dag_source_hash(source: str) -> str:
"""
Hashing file location for indexing.
Hash the source code of the DAG.
:param full_filepath: full filepath of DAG file
:return: hashed full_filepath
This is needed so we can update the source on code changes
"""
# Hashing is needed because the length of fileloc is 2000 as an Airflow convention,
# which is over the limit of indexing.
import hashlib
return md5(source.encode("utf-8")).hexdigest()

# Only 7 bytes because MySQL BigInteger can hold only 8 bytes (signed).
return (
struct.unpack(
">Q", hashlib.sha1(full_filepath.encode("utf-8"), usedforsecurity=False).digest()[-8:]
)[0]
>> 8
)
@classmethod
def _latest_dagcode_select(cls, dag_id: str) -> Select:
    """
    Build the statement that fetches the most recently updated DagCode row of a DAG.

    :param dag_id: The DAG ID.
    :return: The select object.
    """
    rows_for_dag = select(cls).where(cls.dag_id == dag_id)
    # Newest ``last_updated`` first, and keep only that single row.
    newest_first = rows_for_dag.order_by(cls.last_updated.desc())
    return newest_first.limit(1)

@classmethod
@provide_session
def get_latest_dagcode(cls, dag_id: str, session: Session = NEW_SESSION) -> DagCode | None:
    """
    Fetch the most recently updated DagCode row for a DAG.

    :param dag_id: The DAG ID.
    :param session: The database session.
    :return: The latest dagcode or None if not found.
    """
    latest_stmt = cls._latest_dagcode_select(dag_id)
    return session.scalar(latest_stmt)

@classmethod
@provide_session
def update_source_code(cls, dag: DAG, session: Session = NEW_SESSION) -> None:
    """
    Refresh the stored source of a DAG when the content of its file has changed.

    The latest DagCode row of the DAG is compared, by hash, with the code read from
    ``dag.fileloc``; the row is rewritten only when the two hashes differ.

    :param dag: The DAG object.
    :param session: The database session.
    :return: None
    """
    stored = cls.get_latest_dagcode(dag.dag_id, session)
    if not stored:
        # Nothing has been recorded for this DAG yet, so there is nothing to refresh.
        return

    file_source = cls.get_code_from_file(dag.fileloc)
    file_hash = cls.dag_source_hash(file_source)
    if file_hash == stored.source_code_hash:
        # Stored code already matches the file.
        return

    stored.source_code = file_source
    stored.source_code_hash = file_hash
    session.merge(stored)
2 changes: 1 addition & 1 deletion docs/apache-airflow/img/airflow_erd.sha256
Original file line number Diff line number Diff line change
@@ -1 +1 @@
7748eec981f977cc97b852d1fe982aebe24ec2d090ae8493a65cea101f9d42a5
5042271e47bcf1160477200adae4c42ce1cecacf5cbbe7e334d6268debe857fb
Loading

0 comments on commit 6feae13

Please sign in to comment.