[AIRFLOW-2156] Parallelize Celery Executor task state fetching (apach…
KevinYang21 authored and Alice Berard committed Jan 3, 2019
1 parent e496be0 commit e6bfd5e
Showing 7 changed files with 143 additions and 10 deletions.
5 changes: 5 additions & 0 deletions UPDATING.md
@@ -17,6 +17,11 @@ so you might need to update your config.
The scheduler.min_file_parsing_loop_time config option has been temporarily removed due to
some bugs.

### New `sync_parallelism` config option in `[celery]` section

The new `sync_parallelism` config option controls how many processes the CeleryExecutor uses to
fetch Celery task state in parallel. The default value is `max(1, number of cores - 1)`.
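
For example, the effective value can be checked with the same fallback logic the
executor applies (a sketch; `configuration.getint` and the `celery` section name are
taken from this commit, the rest is illustrative):

    from multiprocessing import cpu_count
    from airflow import configuration

    sync_parallelism = configuration.getint('celery', 'SYNC_PARALLELISM')
    if sync_parallelism == 0:
        # 0 is the shipped default and falls back to max(1, cores - 1)
        sync_parallelism = max(1, cpu_count() - 1)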

## Airflow 1.10

Installation and upgrading requires setting `SLUGIFY_USES_TEXT_UNIDECODE=yes` in your environment or
4 changes: 4 additions & 0 deletions airflow/config_templates/default_airflow.cfg
@@ -380,6 +380,10 @@ flower_port = 5555
# Default queue that tasks get assigned to and that workers listen on.
default_queue = default

# How many processes CeleryExecutor uses to sync task state.
# 0 means to use max(1, number of cores - 1) processes.
sync_parallelism = 0

# Import path for celery configuration options
celery_config_options = airflow.config_templates.default_celery.DEFAULT_CELERY_CONFIG

1 change: 1 addition & 0 deletions airflow/config_templates/default_test.cfg
@@ -97,6 +97,7 @@ result_backend = db+mysql://airflow:airflow@localhost:3306/airflow
flower_host = 0.0.0.0
flower_port = 5555
default_queue = default
sync_parallelism = 0

[mesos]
master = localhost:5050
115 changes: 105 additions & 10 deletions airflow/executors/celery_executor.py
@@ -17,20 +17,26 @@
# specific language governing permissions and limitations
# under the License.

import math
import os
import subprocess
import time
import traceback
from multiprocessing import Pool, cpu_count

from celery import Celery
from celery import states as celery_states

from airflow import configuration
from airflow.config_templates.default_celery import DEFAULT_CELERY_CONFIG
from airflow.exceptions import AirflowException
from airflow.executors.base_executor import BaseExecutor
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.module_loading import import_string

# Make it constant for unit test.
CELERY_FETCH_ERR_MSG_HEADER = 'Error fetching Celery task state'

'''
To start the celery worker, run the command:
airflow worker
@@ -63,6 +69,42 @@ def execute_command(command):
raise AirflowException('Celery command failed')


class ExceptionWithTraceback(object):
"""
Wrapper class used to propogate exceptions to parent processes from subprocesses.
:param exception: The exception to wrap
:type exception: Exception
:param traceback: The stacktrace to wrap
:type traceback: str
"""

def __init__(self, exception, exception_traceback):
self.exception = exception
self.traceback = exception_traceback

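# A minimal illustration (not part of this commit) of why the wrapper above is
# needed: multiprocessing.Pool pickles whatever a worker returns or raises, and
# a raised exception loses its worker-side traceback by the time the parent
# sees it, so the traceback is shipped back as a plain string instead:
#
#     def _risky(x):
#         try:
#             return 1 // x
#         except Exception as e:
#             return ExceptionWithTraceback(e, traceback.format_exc())
#
#     for result in Pool(2).map(_risky, [1, 0]):
#         if isinstance(result, ExceptionWithTraceback):
#             print(result.traceback)  # full worker-side stack trace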

def fetch_celery_task_state(celery_task):
"""
Fetch and return the state of the given Celery task. The scope of this function is
global so that it can be called by subprocesses in the pool.

:param celery_task: a tuple of the Celery task key and the async Celery object used
to fetch the task's state
:type celery_task: (str, celery.result.AsyncResult)
:return: a tuple of the Celery task key and the Celery state of the task
:rtype: (str, str)
"""

try:
# Accessing the state property of a Celery task makes an actual network
# request to fetch the current state of the task.
res = (celery_task[0], celery_task[1].state)
except Exception as e:
exception_traceback = "Celery Task ID: {}\n{}".format(celery_task[0],
traceback.format_exc())
res = ExceptionWithTraceback(e, exception_traceback)
return res
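
# Illustrative direct call (the task id is a placeholder; in the executor the
# (key, AsyncResult) pairs are built by execute_async() below):
#
#     result = fetch_celery_task_state(('key', app.AsyncResult('some-task-id')))
#     if not isinstance(result, ExceptionWithTraceback):
#         key, state = result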


class CeleryExecutor(BaseExecutor):
"""
CeleryExecutor is recommended for production use of Airflow. It allows
@@ -72,10 +114,27 @@ class CeleryExecutor(BaseExecutor):
vast amounts of messages, while providing operations with the tools
required to maintain such a system.
"""

def __init__(self):
super(CeleryExecutor, self).__init__()

# Celery doesn't support querying the state of multiple tasks in parallel
# (which can become a bottleneck on bigger clusters) so we use
# a multiprocessing pool to speed this up.
# How many worker processes are created for checking celery task state.
self._sync_parallelism = configuration.getint('celery', 'SYNC_PARALLELISM')
if self._sync_parallelism == 0:
self._sync_parallelism = max(1, cpu_count() - 1)

self._sync_pool = None
self.tasks = {}
self.last_state = {}

def start(self):
self.log.debug(
'Starting Celery Executor using {} processes for syncing'.format(
self._sync_parallelism))

def execute_async(self, key, command,
queue=DEFAULT_CELERY_CONFIG['task_default_queue'],
executor_config=None):
@@ -85,11 +144,48 @@ def execute_async(self, key, command,
args=[command], queue=queue)
self.last_state[key] = celery_states.PENDING

def _num_tasks_per_process(self):
"""
How many Celery tasks should be sent to each worker process.
:return: Number of tasks that should be used per process
:rtype: int
"""
return max(1,
int(math.ceil(1.0 * len(self.tasks) / self._sync_parallelism)))
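
# Worked example with illustrative numbers: 10 queued tasks spread over
# _sync_parallelism = 4 processes gives ceil(10 / 4) = 3 tasks per chunk;
# the max(1, ...) guard keeps the chunk size at 1 even when self.tasks is
# empty (sync() below also returns early in that case):
#
#     max(1, int(math.ceil(1.0 * 10 / 4)))  # == 3
#     max(1, int(math.ceil(1.0 * 0 / 4)))   # == 1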

def sync(self):
self.log.debug("Inquiring about %s celery task(s)", len(self.tasks))
for key, task in list(self.tasks.items()):
num_processes = min(len(self.tasks), self._sync_parallelism)
if num_processes == 0:
self.log.debug("No task to query celery, skipping sync")
return

self.log.debug("Inquiring about %s celery task(s) using %s processes",
len(self.tasks), num_processes)

# Recreate the process pool each sync in case processes in the pool die
self._sync_pool = Pool(processes=num_processes)

# Use chunking instead of a work queue to reduce context switching since tasks are
# roughly uniform in size
chunksize = self._num_tasks_per_process()

self.log.debug("Waiting for inquiries to complete...")
task_keys_to_states = self._sync_pool.map(
fetch_celery_task_state,
self.tasks.items(),
chunksize=chunksize)
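# Pool.map blocks until every fetch completes and returns results in input
# order (a multiprocessing guarantee), so results can be matched to keys
# without extra bookkeeping.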
self._sync_pool.close()
self._sync_pool.join()
self.log.debug("Inquiries completed.")

for key_and_state in task_keys_to_states:
if isinstance(key_and_state, ExceptionWithTraceback):
self.log.error(
CELERY_FETCH_ERR_MSG_HEADER + ", ignoring it:{}\n{}\n".format(
key_and_state.exception, key_and_state.traceback))
continue
key, state = key_and_state
try:
if self.last_state[key] != state:
if state == celery_states.SUCCESS:
self.success(key)
@@ -104,11 +200,10 @@ def sync(self):
del self.tasks[key]
del self.last_state[key]
else:
self.log.info("Unexpected state: %s", state)
self.log.info("Unexpected state: " + state)
self.last_state[key] = state
except Exception:
self.log.exception("Error syncing the Celery executor, ignoring it.")

def end(self, synchronous=False):
if synchronous:
1 change: 1 addition & 0 deletions scripts/ci/airflow_travis.cfg
@@ -59,6 +59,7 @@ broker_url = amqp://guest:guest@rabbitmq:5672/
result_backend = db+mysql://root@mysql/airflow
flower_port = 5555
default_queue = default
sync_parallelism = 0

[celery_broker_transport_options]
visibility_timeout = 21600
4 changes: 4 additions & 0 deletions scripts/ci/kubernetes/kube/configmaps.yaml
@@ -256,6 +256,10 @@ data:
# Default queue that tasks get assigned to and that workers listen on.
default_queue = default
# How many processes CeleryExecutor uses to sync task state.
# 0 means to use max(1, number of cores - 1) processes.
sync_parallelism = 0
# Import path for celery configuration options
celery_config_options = airflow.config_templates.default_celery.DEFAULT_CELERY_CONFIG
23 changes: 23 additions & 0 deletions tests/executors/test_celery_executor.py
@@ -18,10 +18,12 @@
# under the License.
import sys
import unittest
import mock
from celery.contrib.testing.worker import start_worker

from airflow.executors.celery_executor import CeleryExecutor
from airflow.executors.celery_executor import app
from airflow.executors.celery_executor import CELERY_FETCH_ERR_MSG_HEADER
from airflow.utils.state import State

# leave this; it is used by the test worker
@@ -57,5 +59,26 @@ def test_celery_integration(self):
self.assertNotIn('success', executor.last_state)
self.assertNotIn('fail', executor.last_state)

def test_exception_propagation(self):
@app.task
def fake_celery_task():
return {}

mock_log = mock.MagicMock()
executor = CeleryExecutor()
executor._log = mock_log

executor.tasks = {'key': fake_celery_task()}
executor.sync()
mock_log.error.assert_called_once()
args, kwargs = mock_log.error.call_args_list[0]
log = args[0]
# The result of queuing is not a Celery task but a plain dict,
# so fetching its state raises an AttributeError, which should be
# propagated to the error log.
self.assertIn(CELERY_FETCH_ERR_MSG_HEADER, log)
self.assertIn('AttributeError', log)


if __name__ == '__main__':
unittest.main()
