Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue in pre-importing modules in zipfile #31061

Merged
merged 1 commit into from
May 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 40 additions & 13 deletions airflow/dag_processing/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@
import signal
import threading
import time
import zipfile
from contextlib import redirect_stderr, redirect_stdout, suppress
from datetime import datetime, timedelta
from multiprocessing.connection import Connection as MultiprocessingConnection
from typing import TYPE_CHECKING, Iterator
from typing import TYPE_CHECKING, Iterable, Iterator

from setproctitle import setproctitle
from sqlalchemy import exc, func, or_
Expand All @@ -51,7 +52,7 @@
from airflow.stats import Stats
from airflow.utils import timezone
from airflow.utils.email import get_email_address_list, send_email
from airflow.utils.file import iter_airflow_imports
from airflow.utils.file import iter_airflow_imports, might_contain_dag
from airflow.utils.log.logging_mixin import LoggingMixin, StreamLogWriter, set_context
from airflow.utils.mixins import MultiprocessingStartMethodMixin
from airflow.utils.session import NEW_SESSION, provide_session
Expand Down Expand Up @@ -193,18 +194,23 @@ def start(self) -> None:
# Read the file to pre-import airflow modules used.
# This prevents them from being re-imported from zero in each "processing" process
# and saves CPU time and memory.
for module in iter_airflow_imports(self.file_path):
zip_file_paths = []
if zipfile.is_zipfile(self.file_path):
try:
importlib.import_module(module)
except Exception as e:
# only log as warning because an error here is not preventing anything from working, and
# if it's serious, it's going to be surfaced to the user when the dag is actually parsed.
self.log.warning(
"Error when trying to pre-import module '%s' found in %s: %s",
module,
self.file_path,
e,
)
with zipfile.ZipFile(self.file_path) as z:
zip_file_paths.extend(
[
os.path.join(self.file_path, info.filename)
for info in z.infolist()
if might_contain_dag(info.filename, True, z)
]
)
except zipfile.BadZipFile as err:
self.log.error("There was an err accessing %s, %s", self.file_path, err)
if zip_file_paths:
self.import_modules(zip_file_paths)
else:
self.import_modules(self.file_path)

context = self._get_multiprocessing_context()

Expand Down Expand Up @@ -355,6 +361,27 @@ def start_time(self) -> datetime:
def waitable_handle(self):
return self._process.sentinel

def import_modules(self, file_path: str | Iterable[str]):
    """Pre-import the airflow modules used by the given DAG file(s).

    Importing these in the parent process means each forked "processing"
    process inherits the already-imported modules instead of importing
    them again from zero, saving CPU time and memory.

    :param file_path: a single DAG file path, or an iterable of paths
        (e.g. the DAG-file members of a zipped DAG bundle).
    """

    def _import_modules(filepath):
        for module in iter_airflow_imports(filepath):
            try:
                importlib.import_module(module)
            except Exception as e:
                # Only log as warning because an error here is not preventing
                # anything from working; if it's serious, it will be surfaced
                # to the user when the dag is actually parsed.
                self.log.warning(
                    "Error when trying to pre-import module '%s' found in %s: %s",
                    module,
                    filepath,  # log the specific file, not the whole collection
                    e,
                )

    # NOTE: str is itself Iterable, so the str check must come first.
    if isinstance(file_path, str):
        _import_modules(file_path)
    elif isinstance(file_path, Iterable):
        for path in file_path:
            _import_modules(path)


class DagFileProcessor(LoggingMixin):
"""
Expand Down
17 changes: 17 additions & 0 deletions tests/dag_processing/test_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,23 @@ def test_dag_parser_output_when_logging_to_file(self, mock_redirect_stdout_for_f
)
mock_redirect_stdout_for_file.assert_called_once()

@mock.patch("airflow.dag_processing.processor.settings.dispose_orm", MagicMock)
@mock.patch.object(DagFileProcessorProcess, "_get_multiprocessing_context")
def test_no_valueerror_with_parseable_dag_in_zip(self, mock_context, tmpdir):
    # Regression test: starting the processor on a zipped DAG file must not
    # raise while pre-importing the modules used by the archived DAG.
    mock_context.return_value.Pipe.return_value = (MagicMock(), MagicMock())

    zip_path = os.path.join(tmpdir, "test_zip.zip")
    with ZipFile(zip_path, "w") as archive:
        archive.writestr(TEMP_DAG_FILENAME, PARSEABLE_DAG_FILE_CONTENTS)

    processor = DagFileProcessorProcess(
        callback_requests=[],
        dag_directory=[],
        dag_ids=[],
        file_path=zip_path,
        pickle_dags=False,
    )
    processor.start()


class TestProcessorAgent:
@pytest.fixture(autouse=True)
Expand Down