Skip to content

Commit

Permalink
[AIRFLOW-5946] DAG Serialization: Store source code in db (#7217)
Browse files Browse the repository at this point in the history
* DAG serialization improvement: the DAG's source code is now stored in the dag_code table and is queried from here when the Code view is opened for the DAG. The webserver no longer needs access to the dags folder in the shared filesystem.
  • Loading branch information
anitakar authored Mar 13, 2020
1 parent a86924f commit e146518
Show file tree
Hide file tree
Showing 22 changed files with 562 additions and 112 deletions.
6 changes: 2 additions & 4 deletions airflow/api/common/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from datetime import datetime
from typing import Optional

from airflow.configuration import conf
from airflow.exceptions import DagNotFound, DagRunNotFound, TaskNotFound
from airflow.models import DagBag, DagModel, DagRun

Expand All @@ -29,12 +30,9 @@ def check_and_get_dag(dag_id: str, task_id: Optional[str] = None) -> DagModel:
if dag_model is None:
raise DagNotFound("Dag id {} not found in DagModel".format(dag_id))

def read_store_serialized_dags():
from airflow.configuration import conf
return conf.getboolean('core', 'store_serialized_dags')
dagbag = DagBag(
dag_folder=dag_model.fileloc,
store_serialized_dags=read_store_serialized_dags()
store_serialized_dags=conf.getboolean('core', 'store_serialized_dags')
)
dag = dagbag.get_dag(dag_id) # prefetch dag if it is stored serialized
if dag_id not in dagbag.dags:
Expand Down
12 changes: 5 additions & 7 deletions airflow/api/common/experimental/get_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
# under the License.
"""Get code APIs."""
from airflow.api.common.experimental import check_and_get_dag
from airflow.exceptions import AirflowException
from airflow.www import utils as wwwutils
from airflow.exceptions import AirflowException, DagCodeNotFound
from airflow.models.dagcode import DagCode


def get_code(dag_id: str) -> str:
Expand All @@ -30,9 +30,7 @@ def get_code(dag_id: str) -> str:
dag = check_and_get_dag(dag_id=dag_id)

try:
with wwwutils.open_maybe_zipped(dag.fileloc, 'r') as file:
code = file.read()
return code
except OSError as exception:
return DagCode.get_code_by_fileloc(dag.fileloc)
except (OSError, DagCodeNotFound) as exception:
error_message = "Error {} while reading Dag id {} Code".format(str(exception), dag_id)
raise AirflowException(error_message)
raise AirflowException(error_message, exception)
10 changes: 10 additions & 0 deletions airflow/config_templates/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,16 @@
type: string
example: ~
default: "30"
- name: store_dag_code
description: |
Whether to persist DAG files code in DB.
If set to True, Webserver reads file contents from DB instead of
trying to access files in a DAG folder. Defaults to same as the
store_serialized_dags setting.
version_added: 2.0.0
type: string
example: ~
default: "%(store_serialized_dags)s"
- name: max_num_rendered_ti_fields_per_task
description: |
Maximum number of Rendered Task Instance Fields (Template Fields) per task to store
Expand Down
6 changes: 6 additions & 0 deletions airflow/config_templates/default_airflow.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,12 @@ store_serialized_dags = False
# Updating serialized DAG can not be faster than a minimum interval to reduce database write rate.
min_serialized_dag_update_interval = 30

# Whether to persist DAG files code in DB.
# If set to True, Webserver reads file contents from DB instead of
# trying to access files in a DAG folder. Defaults to same as the
# store_serialized_dags setting.
store_dag_code = %(store_serialized_dags)s

# Maximum number of Rendered Task Instance Fields (Template Fields) per task to store
# in the Database.
# When Dag Serialization is enabled (``store_serialized_dags=True``), all the template_fields
Expand Down
4 changes: 4 additions & 0 deletions airflow/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ class DagNotFound(AirflowNotFoundException):
"""Raise when a DAG is not available in the system"""


class DagCodeNotFound(AirflowNotFoundException):
    """Raise when a DAG's source code is not available in the system"""


class DagRunNotFound(AirflowNotFoundException):
"""Raise when a DAG Run is not available in the system"""

Expand Down
68 changes: 68 additions & 0 deletions airflow/migrations/versions/952da73b5eff_add_dag_code_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""add dag_code table
Revision ID: 952da73b5eff
Revises: 852ae6c715af
Create Date: 2020-03-12 12:39:01.797462
"""

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
from airflow.models.dagcode import DagCode
from airflow.models.serialized_dag import SerializedDagModel

# Revision identifiers, used by Alembic to order this migration in the chain:
# ``revision`` is this migration's id and ``down_revision`` is its parent.
revision = '952da73b5eff'
down_revision = '852ae6c715af'
branch_labels = None
depends_on = None


def upgrade():
    """Create the ``dag_code`` table and re-hash ``serialized_dag.fileloc_hash``.

    Steps performed:

    1. Create the ``dag_code`` table (hash of the file location as primary
       key, file location, source code and last-updated timestamp).
    2. On every dialect except SQLite, widen ``serialized_dag.fileloc_hash``
       to ``BigInteger`` and rebuild its index.
    3. Recompute ``fileloc_hash`` for every existing serialized DAG with
       ``DagCode.dag_fileloc_hash`` so both tables use the same hash.
    """
    op.create_table('dag_code',  # pylint: disable=no-member
                    sa.Column('fileloc_hash', sa.BigInteger(),
                              nullable=False, primary_key=True, autoincrement=False),
                    sa.Column('fileloc', sa.String(length=2000), nullable=False),
                    sa.Column('source_code', sa.UnicodeText(), nullable=False),
                    sa.Column('last_updated', sa.TIMESTAMP(timezone=True), nullable=False))

    conn = op.get_bind()
    # Compare the dialect name explicitly: ``name not in ('sqlite')`` is a
    # substring test against the *string* 'sqlite', not a tuple membership test.
    # NOTE(review): SQLite is presumably skipped because of its limited
    # ALTER COLUMN support -- confirm.
    if conn.dialect.name != 'sqlite':
        op.drop_index('idx_fileloc_hash', 'serialized_dag')
        op.alter_column(table_name='serialized_dag', column_name='fileloc_hash',
                        type_=sa.BigInteger(), nullable=False)
        op.create_index(  # pylint: disable=no-member
            'idx_fileloc_hash', 'serialized_dag', ['fileloc_hash'])

    # Re-hash existing rows on the migration's own connection.
    sessionmaker = sa.orm.sessionmaker()
    session = sessionmaker(bind=conn)
    serialized_dags = session.query(SerializedDagModel).all()
    for dag in serialized_dags:
        dag.fileloc_hash = DagCode.dag_fileloc_hash(dag.fileloc)
        session.merge(dag)
    session.commit()


def downgrade():
    """Revert this migration by dropping the ``dag_code`` table."""
    table_name = 'dag_code'
    op.drop_table(table_name)
4 changes: 4 additions & 0 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from airflow.models.base import ID_LEN, Base
from airflow.models.baseoperator import BaseOperator
from airflow.models.dagbag import DagBag
from airflow.models.dagcode import DagCode
from airflow.models.dagpickle import DagPickle
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance, clear_task_instances
Expand Down Expand Up @@ -1528,6 +1529,9 @@ def bulk_sync_to_db(cls, dags: Collection["DAG"], sync_time=None, session=None):
orm_dag.tags.append(dag_tag_orm)
session.add(dag_tag_orm)

if conf.getboolean('core', 'store_dag_code', fallback=False):
DagCode.bulk_sync_to_db([dag.fileloc for dag in orm_dags])

session.commit()

for dag in dags:
Expand Down
213 changes: 213 additions & 0 deletions airflow/models/dagcode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
import os
import struct
from datetime import datetime, timedelta
from typing import Iterable, List

from sqlalchemy import BigInteger, Column, String, UnicodeText, and_, exists

from airflow.configuration import conf
from airflow.exceptions import AirflowException, DagCodeNotFound
from airflow.models import Base
from airflow.utils import timezone
from airflow.utils.file import correct_maybe_zipped, open_maybe_zipped
from airflow.utils.session import provide_session
from airflow.utils.sqlalchemy import UtcDateTime

log = logging.getLogger(__name__)


class DagCode(Base):
    """ORM model for the ``dag_code`` table, which stores DAG file source code.

    Each row holds the source code of one DAG file, keyed by a hash of the
    file location (see :meth:`dag_fileloc_hash`).

    Reading is controlled by ``[core] store_dag_code``:

    * ``True``: source code is read from this table.
    * ``False`` (the fallback used in this class): source code is read from
      the DAG file itself.

    For details on dag serialization see SerializedDagModel
    """
    __tablename__ = 'dag_code'

    fileloc_hash = Column(
        BigInteger, nullable=False, primary_key=True, autoincrement=False)
    fileloc = Column(String(2000), nullable=False)
    # The max length of fileloc exceeds the limit of indexing.
    last_updated = Column(UtcDateTime, nullable=False)
    source_code = Column(UnicodeText, nullable=False)

    def __init__(self, full_filepath: str):
        """Build a row for ``full_filepath``, reading its source code from disk.

        :param full_filepath: full file path of the DAG file
        """
        self.fileloc = full_filepath
        self.fileloc_hash = DagCode.dag_fileloc_hash(self.fileloc)
        self.last_updated = timezone.utcnow()
        self.source_code = DagCode._read_code(self.fileloc)

    @classmethod
    def _read_code(cls, fileloc: str):
        """Return the text content of ``fileloc`` (which may be inside a zip)."""
        with open_maybe_zipped(fileloc, 'r') as source:
            source_code = source.read()
        return source_code

    @provide_session
    def sync_to_db(self, session=None):
        """Writes code into database.

        :param session: ORM Session
        """
        self.bulk_sync_to_db([self.fileloc], session)

    @classmethod
    @provide_session
    def bulk_sync_to_db(cls, filelocs: Iterable[str], session=None):
        """Writes code in bulk into database.

        New files get a new row; existing rows are refreshed when the file on
        disk was modified more than 120 seconds after the row's
        ``last_updated``.

        :param filelocs: file paths of DAGs to sync
        :param session: ORM Session
        :raises AirflowException: if a requested file location hashes to the
            same value as a different file location already stored
        """
        filelocs = set(filelocs)
        if not filelocs:
            # Nothing to sync; also avoids issuing an empty ``IN ()`` query.
            return

        filelocs_to_hashes = {
            fileloc: DagCode.dag_fileloc_hash(fileloc) for fileloc in filelocs
        }
        existing_orm_dag_codes = (
            session
            .query(DagCode)
            .filter(DagCode.fileloc_hash.in_(list(filelocs_to_hashes.values())))
            .with_for_update(of=DagCode)
            .all()
        )

        # A stored row whose hash matched but whose path is not one of the
        # requested paths means two different paths share the same hash.
        conflicting_rows = [
            orm for orm in existing_orm_dag_codes if orm.fileloc not in filelocs
        ]
        if conflicting_rows:
            hashes_to_filelocs = {
                fileloc_hash: fileloc
                for fileloc, fileloc_hash in filelocs_to_hashes.items()
            }
            message = "".join(
                ("Filename '{}' causes a hash collision in the " +
                 "database with '{}'. Please rename the file.")
                .format(hashes_to_filelocs[orm.fileloc_hash], orm.fileloc)
                for orm in conflicting_rows
            )
            raise AirflowException(message)

        existing_filelocs = {
            dag_code.fileloc for dag_code in existing_orm_dag_codes
        }

        for fileloc in filelocs.difference(existing_filelocs):
            session.add(DagCode(fileloc))

        for existing_dag_code in existing_orm_dag_codes:
            file_modified = datetime.fromtimestamp(
                os.path.getmtime(correct_maybe_zipped(existing_dag_code.fileloc)),
                tz=timezone.utc)

            # NOTE(review): the 120 second margin is presumably a tolerance
            # for clock skew / repeated syncs -- confirm.
            if (file_modified - timedelta(seconds=120)) > existing_dag_code.last_updated:
                # Refresh the row that was actually loaded for this file.
                existing_dag_code.last_updated = timezone.utcnow()
                existing_dag_code.source_code = DagCode._read_code(
                    existing_dag_code.fileloc)
                session.merge(existing_dag_code)

    @classmethod
    @provide_session
    def remove_deleted_code(cls, alive_dag_filelocs: List[str], session=None):
        """Deletes code not included in alive_dag_filelocs.

        :param alive_dag_filelocs: file paths of alive DAGs
        :param session: ORM Session
        """
        alive_fileloc_hashes = [
            cls.dag_fileloc_hash(fileloc) for fileloc in alive_dag_filelocs]

        log.debug("Deleting code from %s table ", cls.__tablename__)

        # ``Query.delete()`` executes the DELETE itself and returns a row
        # count, so it must not be wrapped in ``session.execute``.
        # ``synchronize_session='fetch'`` is used because the ``notin_``
        # criteria cannot be evaluated in Python against in-session objects.
        session.query(cls).filter(
            and_(cls.fileloc_hash.notin_(alive_fileloc_hashes),
                 cls.fileloc.notin_(alive_dag_filelocs))
        ).delete(synchronize_session='fetch')

    @classmethod
    @provide_session
    def has_dag(cls, fileloc: str, session=None) -> bool:
        """Checks a file exist in dag_code table.

        :param fileloc: the file to check
        :param session: ORM Session
        :return: True if a row with the hash of ``fileloc`` exists
        """
        fileloc_hash = cls.dag_fileloc_hash(fileloc)
        return session.query(exists().where(cls.fileloc_hash == fileloc_hash))\
            .scalar()

    @classmethod
    @provide_session
    def get_code_by_fileloc(cls, fileloc: str, session=None) -> str:
        """Returns source code for a given fileloc.

        When ``[core] store_dag_code`` is enabled the code is read from the
        database only; the DAG file is not opened. Otherwise the file is read.

        :param fileloc: file path of a DAG
        :param session: ORM Session
        :return: source code as string
        :raises DagCodeNotFound: if reading from the database and no row
            exists for ``fileloc``
        """
        # Do not instantiate ``DagCode(fileloc)`` here: its constructor reads
        # the file from disk, which fails when only the database is reachable.
        if conf.getboolean('core', 'store_dag_code', fallback=False):
            return cls._get_code_by_hash(cls.dag_fileloc_hash(fileloc), session)
        return cls._read_code(fileloc)

    def code(self) -> str:
        """Returns source code for this DagCode object.

        :return: source code as string
        """
        if conf.getboolean('core', 'store_dag_code', fallback=False):
            return self._get_code_from_db()
        return self._get_code_from_file()

    def _get_code_from_file(self):
        """Return the source code by reading ``self.fileloc`` from disk."""
        return DagCode._read_code(self.fileloc)

    @provide_session
    def _get_code_from_db(self, session=None):
        """Return the source code stored in the database for this object.

        :param session: ORM Session
        :raises DagCodeNotFound: if no row exists for ``self.fileloc_hash``
        """
        return DagCode._get_code_by_hash(self.fileloc_hash, session)

    @classmethod
    def _get_code_by_hash(cls, fileloc_hash: int, session) -> str:
        """Return the stored source code of the row with ``fileloc_hash``.

        :param fileloc_hash: hash of the DAG file location
        :param session: ORM Session
        :raises DagCodeNotFound: if no such row exists
        """
        dag_code = session.query(cls) \
            .filter(cls.fileloc_hash == fileloc_hash) \
            .first()
        if not dag_code:
            raise DagCodeNotFound()
        return dag_code.source_code

    @staticmethod
    def dag_fileloc_hash(full_filepath: str) -> int:
        """Hashing file location for indexing.

        :param full_filepath: full filepath of DAG file
        :return: hashed full_filepath
        """
        # Hashing is needed because the length of fileloc is 2000 as an Airflow convention,
        # which is over the limit of indexing.
        import hashlib
        # Only 7 bytes because MySQL BigInteger can hold only 8 bytes (signed).
        return struct.unpack('>Q', hashlib.sha1(
            full_filepath.encode('utf-8')).digest()[-8:])[0] >> 8
Loading

0 comments on commit e146518

Please sign in to comment.