Skip to content

Commit

Permalink
Allow using Airflow with Flask CLI (#9030)
Browse files Browse the repository at this point in the history
  • Loading branch information
mik-laj authored Jun 2, 2020
1 parent a6216a7 commit 87a4a0a
Show file tree
Hide file tree
Showing 19 changed files with 102 additions and 108 deletions.
6 changes: 3 additions & 3 deletions airflow/cli/commands/role_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@
from tabulate import tabulate

from airflow.utils import cli as cli_utils
from airflow.www.app import cached_appbuilder
from airflow.www.app import cached_app


def roles_list(args):
"""Lists all existing roles"""
appbuilder = cached_appbuilder()
appbuilder = cached_app().appbuilder # pylint: disable=no-member
roles = appbuilder.sm.get_all_roles()
print("Existing roles:\n")
role_names = sorted([[r.name] for r in roles])
Expand All @@ -38,6 +38,6 @@ def roles_list(args):
@cli_utils.action_logging
def roles_create(args):
"""Creates new empty role in DB"""
appbuilder = cached_appbuilder()
appbuilder = cached_app().appbuilder # pylint: disable=no-member
for role_name in args.role:
appbuilder.sm.add_role(role_name)
4 changes: 2 additions & 2 deletions airflow/cli/commands/sync_perm_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
"""Sync permission command"""
from airflow.models import DagBag
from airflow.utils import cli as cli_utils
from airflow.www.app import cached_appbuilder
from airflow.www.app import cached_app


@cli_utils.action_logging
def sync_perm(args):
"""Updates permissions for existing roles and DAGs"""
appbuilder = cached_appbuilder()
appbuilder = cached_app().appbuilder # pylint: disable=no-member
print('Updating permission, view-menu for all existing roles')
appbuilder.sm.sync_roles()
print('Updating permission on all DAG views')
Expand Down
14 changes: 7 additions & 7 deletions airflow/cli/commands/user_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@
from tabulate import tabulate

from airflow.utils import cli as cli_utils
from airflow.www.app import cached_appbuilder
from airflow.www.app import cached_app


def users_list(args):
"""Lists users at the command line"""
appbuilder = cached_appbuilder()
appbuilder = cached_app().appbuilder # pylint: disable=no-member
users = appbuilder.sm.get_all_users()
fields = ['id', 'username', 'email', 'first_name', 'last_name', 'roles']
users = [[user.__getattribute__(field) for field in fields] for user in users]
Expand All @@ -44,7 +44,7 @@ def users_list(args):
@cli_utils.action_logging
def users_create(args):
"""Creates new user in the DB"""
appbuilder = cached_appbuilder()
appbuilder = cached_app().appbuilder # pylint: disable=no-member
role = appbuilder.sm.find_role(args.role)
if not role:
valid_roles = appbuilder.sm.get_all_roles()
Expand Down Expand Up @@ -74,7 +74,7 @@ def users_create(args):
@cli_utils.action_logging
def users_delete(args):
"""Deletes user from DB"""
appbuilder = cached_appbuilder()
appbuilder = cached_app().appbuilder # pylint: disable=no-member

try:
user = next(u for u in appbuilder.sm.get_all_users()
Expand All @@ -98,7 +98,7 @@ def users_manage_role(args, remove=False):
raise SystemExit('Conflicting args: must supply either --username'
' or --email, but not both')

appbuilder = cached_appbuilder()
appbuilder = cached_app().appbuilder # pylint: disable=no-member
user = (appbuilder.sm.find_user(username=args.username) or
appbuilder.sm.find_user(email=args.email))
if not user:
Expand Down Expand Up @@ -136,7 +136,7 @@ def users_manage_role(args, remove=False):

def users_export(args):
"""Exports all users to the json file"""
appbuilder = cached_appbuilder()
appbuilder = cached_app().appbuilder # pylint: disable=no-member
users = appbuilder.sm.get_all_users()
fields = ['id', 'username', 'email', 'first_name', 'last_name', 'roles']

Expand Down Expand Up @@ -184,7 +184,7 @@ def users_import(args):


def _import_users(users_list): # pylint: disable=redefined-outer-name
appbuilder = cached_appbuilder()
appbuilder = cached_app().appbuilder # pylint: disable=no-member
users_created = []
users_updated = []

Expand Down
2 changes: 1 addition & 1 deletion airflow/cli/commands/webserver_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def webserver(args):
print(
"Starting the web server on port {0} and host {1}.".format(
args.port, args.hostname))
app, _ = create_app(testing=conf.getboolean('core', 'unit_test_mode'))
app = create_app(testing=conf.getboolean('core', 'unit_test_mode'))
app.run(debug=True, use_reloader=not app.config['TESTING'],
port=args.port, host=args.hostname,
ssl_context=(ssl_cert, ssl_key) if ssl_cert and ssl_key else None)
Expand Down
69 changes: 36 additions & 33 deletions airflow/www/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import logging
import socket
from datetime import timedelta
from typing import Any, Optional
from typing import Optional
from urllib.parse import urlparse

import flask
Expand All @@ -39,15 +39,18 @@
from airflow.utils.json import AirflowJsonEncoder
from airflow.www.static_config import configure_manifest_files

app = None # type: Any
appbuilder = None # type: Optional[AppBuilder]
app: Optional[Flask] = None
csrf = CSRFProtect()

log = logging.getLogger(__name__)


def root_app(env, resp):
resp(b'404 Not Found', [('Content-Type', 'text/plain')])
return [b'Apache Airflow is not at this location']


def create_app(config=None, testing=False, app_name="Airflow"):
global app, appbuilder
app = Flask(__name__)
app.secret_key = conf.get('webserver', 'SECRET_KEY')

Expand All @@ -70,6 +73,31 @@ def create_app(config=None, testing=False, app_name="Airflow"):
app.json_encoder = AirflowJsonEncoder

csrf.init_app(app)

def apply_middlewares(flask_app: Flask):
# Apply DispatcherMiddleware
base_url = urlparse(conf.get('webserver', 'base_url'))[2]
if not base_url or base_url == '/':
base_url = ""
if base_url:
flask_app.wsgi_app = DispatcherMiddleware( # type: ignore
root_app,
mounts={base_url: flask_app.wsgi_app}
)

# Apply ProxyFix middleware
if conf.getboolean('webserver', 'ENABLE_PROXY_FIX'):
flask_app.wsgi_app = ProxyFix( # type: ignore
flask_app.wsgi_app,
x_for=conf.getint("webserver", "PROXY_FIX_X_FOR", fallback=1),
x_proto=conf.getint("webserver", "PROXY_FIX_X_PROTO", fallback=1),
x_host=conf.getint("webserver", "PROXY_FIX_X_HOST", fallback=1),
x_port=conf.getint("webserver", "PROXY_FIX_X_PORT", fallback=1),
x_prefix=conf.getint("webserver", "PROXY_FIX_X_PREFIX", fallback=1)
)

apply_middlewares(app)

db = SQLA()
db.session = settings.Session
db.init_app(app)
Expand Down Expand Up @@ -286,36 +314,11 @@ def apply_caching(response):
def make_session_permanent():
flask_session.permanent = True

return app, appbuilder


def root_app(env, resp):
resp(b'404 Not Found', [('Content-Type', 'text/plain')])
return [b'Apache Airflow is not at this location']
return app


def cached_app(config=None, testing=False):
global app, appbuilder
if not app or not appbuilder:
base_url = urlparse(conf.get('webserver', 'base_url'))[2]
if not base_url or base_url == '/':
base_url = ""

app, _ = create_app(config=config, testing=testing)
app = DispatcherMiddleware(root_app, {base_url: app})
if conf.getboolean('webserver', 'ENABLE_PROXY_FIX'):
app = ProxyFix(
app,
x_for=conf.getint("webserver", "PROXY_FIX_X_FOR", fallback=1),
x_proto=conf.getint("webserver", "PROXY_FIX_X_PROTO", fallback=1),
x_host=conf.getint("webserver", "PROXY_FIX_X_HOST", fallback=1),
x_port=conf.getint("webserver", "PROXY_FIX_X_PORT", fallback=1),
x_prefix=conf.getint("webserver", "PROXY_FIX_X_PREFIX", fallback=1)
)
global app
if not app:
app = create_app(config=config, testing=testing)
return app


def cached_appbuilder(config=None, testing=False):
global appbuilder
cached_app(config=config, testing=testing)
return appbuilder
7 changes: 3 additions & 4 deletions airflow/www/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# under the License.
#

from flask import g
from flask import current_app, g
from flask_appbuilder.security.sqla import models as sqla_models
from flask_appbuilder.security.sqla.manager import SecurityManager
from sqlalchemy import and_, or_
Expand All @@ -26,7 +26,6 @@
from airflow.exceptions import AirflowException
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import provide_session
from airflow.www.app import appbuilder
from airflow.www.utils import CustomSQLAInterface

EXISTING_ROLES = {
Expand Down Expand Up @@ -250,8 +249,8 @@ def get_user_roles(user=None):
if user is None:
user = g.user
if user.is_anonymous:
public_role = appbuilder.config.get('AUTH_ROLE_PUBLIC')
return [appbuilder.security_manager.find_role(public_role)] \
public_role = current_app.appbuilder.config.get('AUTH_ROLE_PUBLIC')
return [current_app.appbuilder.security_manager.find_role(public_role)] \
if public_role else []
return user.roles

Expand Down
21 changes: 10 additions & 11 deletions airflow/www/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
import markdown
import sqlalchemy as sqla
from flask import (
Markup, Response, escape, flash, jsonify, make_response, redirect, render_template, request,
Markup, Response, current_app, escape, flash, jsonify, make_response, redirect, render_template, request,
session as flask_session, url_for,
)
from flask_appbuilder import BaseView, ModelView, expose, has_access, permission_name
Expand Down Expand Up @@ -72,7 +72,6 @@
from airflow.utils.session import create_session, provide_session
from airflow.utils.state import State
from airflow.www import utils as wwwutils
from airflow.www.app import appbuilder
from airflow.www.decorators import action_logging, gzipped, has_dag_access
from airflow.www.forms import (
ConnectionForm, DagRunForm, DateTimeForm, DateTimeWithNumRunsForm, DateTimeWithNumRunsWithDagRunsForm,
Expand Down Expand Up @@ -270,7 +269,7 @@ def get_int_arg(value, default=0):
end = start + dags_per_page

# Get all the dag id the user could access
filter_dag_ids = appbuilder.sm.get_accessible_dag_ids()
filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()

with create_session() as session:
# read orm_dags from the db
Expand Down Expand Up @@ -368,7 +367,7 @@ def get_int_arg(value, default=0):
def dag_stats(self, session=None):
dr = models.DagRun

allowed_dag_ids = appbuilder.sm.get_accessible_dag_ids()
allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()
if 'all_dags' in allowed_dag_ids:
allowed_dag_ids = [dag_id for dag_id, in session.query(models.DagModel.dag_id)]

Expand Down Expand Up @@ -416,7 +415,7 @@ def task_stats(self, session=None):
DagRun = models.DagRun
Dag = models.DagModel

allowed_dag_ids = set(appbuilder.sm.get_accessible_dag_ids())
allowed_dag_ids = set(current_app.appbuilder.sm.get_accessible_dag_ids())

if not allowed_dag_ids:
return wwwutils.json_response({})
Expand Down Expand Up @@ -512,7 +511,7 @@ def task_stats(self, session=None):
def last_dagruns(self, session=None):
DagRun = models.DagRun

allowed_dag_ids = appbuilder.sm.get_accessible_dag_ids()
allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()

if 'all_dags' in allowed_dag_ids:
allowed_dag_ids = [dag_id for dag_id, in session.query(models.DagModel.dag_id)]
Expand Down Expand Up @@ -1167,7 +1166,7 @@ def dagrun_clear(self):
@has_access
@provide_session
def blocked(self, session=None):
allowed_dag_ids = appbuilder.sm.get_accessible_dag_ids()
allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()

if 'all_dags' in allowed_dag_ids:
allowed_dag_ids = [dag_id for dag_id, in session.query(models.DagModel.dag_id)]
Expand Down Expand Up @@ -1912,7 +1911,7 @@ def refresh(self, session=None):

dag = dagbag.get_dag(dag_id)
# sync dag permission
appbuilder.sm.sync_perm_for_dag(dag_id, dag.access_control)
current_app.appbuilder.sm.sync_perm_for_dag(dag_id, dag.access_control)

flash("DAG [{}] is now fresh as a daisy".format(dag_id))
return redirect(request.referrer)
Expand Down Expand Up @@ -2163,9 +2162,9 @@ def conf(self):

class DagFilter(BaseFilter):
def apply(self, query, func): # noqa
if appbuilder.sm.has_all_dags_access():
if current_app.appbuilder.sm.has_all_dags_access():
return query
filter_dag_ids = appbuilder.sm.get_accessible_dag_ids()
filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()
return query.filter(self.model.dag_id.in_(filter_dag_ids))


Expand Down Expand Up @@ -2800,7 +2799,7 @@ def autocomplete(self, session=None):
dag_ids_query = dag_ids_query.filter(DagModel.is_paused)
owners_query = owners_query.filter(DagModel.is_paused)

filter_dag_ids = appbuilder.sm.get_accessible_dag_ids()
filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()
if 'all_dags' not in filter_dag_ids:
dag_ids_query = dag_ids_query.filter(DagModel.dag_id.in_(filter_dag_ids))
owners_query = owners_query.filter(DagModel.dag_id.in_(filter_dag_ids))
Expand Down
5 changes: 1 addition & 4 deletions tests/cli/commands/test_celery_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,6 @@
from airflow.configuration import conf
from tests.test_utils.config import conf_vars

mock.patch('airflow.utils.cli.action_logging', lambda x: x).start()
mock_args = Namespace(queues=1, concurrency=1)


class TestWorkerPrecheck(unittest.TestCase):
@mock.patch('airflow.settings.validate_session')
Expand All @@ -42,7 +39,7 @@ def test_error(self, mock_validate_session):
"""
mock_validate_session.return_value = False
with self.assertRaises(SystemExit) as cm:
celery_command.worker(mock_args)
celery_command.worker(Namespace(queues=1, concurrency=1))
self.assertEqual(cm.exception.code, 1)

@conf_vars({('core', 'worker_precheck'): 'False'})
Expand Down
3 changes: 2 additions & 1 deletion tests/cli/commands/test_role_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ def setUpClass(cls):

def setUp(self):
from airflow.www import app as application
self.app, self.appbuilder = application.create_app(testing=True)
self.app = application.create_app(testing=True)
self.appbuilder = self.app.appbuilder # pylint: disable=no-member
self.clear_roles_and_roles()

def tearDown(self):
Expand Down
Loading

0 comments on commit 87a4a0a

Please sign in to comment.