From 608f232accdfbe1bd1e2c420be71c675e6e1453d Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Wed, 21 Sep 2022 13:14:50 +0100 Subject: [PATCH] Correctly set json_provider_class on Flask app so it uses our encoder Setting `json_provider_class` where we did had no effect, as it turns out `Flask()` sets `self.json = self.json_provider_class(self)`, so we were setting it too late. --- airflow/utils/json.py | 20 ++++++++++++++++++-- airflow/www/app.py | 5 +++-- airflow/www/utils.py | 10 +--------- airflow/www/views.py | 33 +++++++++++++++++---------------- tests/www/test_app.py | 10 ++++++++++ 5 files changed, 49 insertions(+), 29 deletions(-) diff --git a/airflow/utils/json.py b/airflow/utils/json.py index fcc4eedd6e7f4..ff1109782418d 100644 --- a/airflow/utils/json.py +++ b/airflow/utils/json.py @@ -17,11 +17,12 @@ # under the License. from __future__ import annotations +import json import logging from datetime import date, datetime from decimal import Decimal -from flask.json import JSONEncoder +from flask.json.provider import JSONProvider from airflow.utils.timezone import convert_to_utc, is_naive @@ -40,7 +41,7 @@ log = logging.getLogger(__name__) -class AirflowJsonEncoder(JSONEncoder): +class AirflowJsonEncoder(json.JSONEncoder): """Custom Airflow json encoder implementation.""" def __init__(self, *args, **kwargs): @@ -107,3 +108,18 @@ def safe_get_name(pod): return {} raise TypeError(f"Object of type '{obj.__class__.__name__}' is not JSON serializable") + + +class AirflowJsonProvider(JSONProvider): + """JSON Provider for Flask app to use AirflowJsonEncoder.""" + + ensure_ascii: bool = True + sort_keys: bool = True + + def dumps(self, obj, **kwargs): + kwargs.setdefault('ensure_ascii', self.ensure_ascii) + kwargs.setdefault('sort_keys', self.sort_keys) + return json.dumps(obj, **kwargs, cls=AirflowJsonEncoder) + + def loads(self, s: str | bytes, **kwargs): + return json.loads(s, **kwargs) diff --git a/airflow/www/app.py b/airflow/www/app.py index 
d40f3badb8c6b..b67314c99a8e9 100644 --- a/airflow/www/app.py +++ b/airflow/www/app.py @@ -32,7 +32,7 @@ from airflow.exceptions import AirflowConfigException, RemovedInAirflow3Warning from airflow.logging_config import configure_logging from airflow.models import import_all_models -from airflow.utils.json import AirflowJsonEncoder +from airflow.utils.json import AirflowJsonProvider from airflow.www.extensions.init_appbuilder import init_appbuilder from airflow.www.extensions.init_appbuilder_links import init_appbuilder_links from airflow.www.extensions.init_dagbag import init_dagbag @@ -109,7 +109,8 @@ def create_app(config=None, testing=False): flask_app.config['SQLALCHEMY_ENGINE_OPTIONS'] = settings.prepare_engine_args() # Configure the JSON encoder used by `|tojson` filter from Flask - flask_app.json_provider_class = AirflowJsonEncoder + flask_app.json_provider_class = AirflowJsonProvider + flask_app.json = AirflowJsonProvider(flask_app) csrf.init_app(flask_app) diff --git a/airflow/www/utils.py b/airflow/www/utils.py index d0efa611d5dd4..d49b73249717b 100644 --- a/airflow/www/utils.py +++ b/airflow/www/utils.py @@ -24,7 +24,7 @@ from urllib.parse import urlencode import sqlalchemy as sqla -from flask import Response, request, url_for +from flask import request, url_for from flask.helpers import flash from flask_appbuilder.forms import FieldConverter from flask_appbuilder.models.filters import BaseFilter @@ -47,7 +47,6 @@ from airflow.utils import timezone from airflow.utils.code_utils import get_python_source from airflow.utils.helpers import alchemy_to_dict -from airflow.utils.json import AirflowJsonEncoder from airflow.utils.state import State, TaskInstanceState from airflow.www.forms import DateTimeWithTimezoneField from airflow.www.widgets import AirflowDateTimePickerWidget @@ -322,13 +321,6 @@ def epoch(dttm): return (int(time.mktime(dttm.timetuple())) * 1000,) -def json_response(obj): - """Returns a json response from a json serializable python object""" 
- return Response( - response=json.dumps(obj, indent=4, cls=AirflowJsonEncoder), status=200, mimetype="application/json" - ) - - def make_cache_key(*args, **kwargs): """Used by cache to get a unique key per URL""" path = request.path diff --git a/airflow/www/views.py b/airflow/www/views.py index 2c6701505cd83..1aa6acb8b21e7 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -37,6 +37,7 @@ from urllib.parse import parse_qsl, unquote, urlencode, urlparse import configupdater +import flask.json import lazy_object_proxy import markupsafe import nvd3 @@ -107,7 +108,7 @@ from airflow.ti_deps.dependencies_deps import RUNNING_DEPS, SCHEDULER_QUEUED_DEPS from airflow.timetables.base import DataInterval, TimeRestriction from airflow.timetables.interval import CronDataIntervalTimetable -from airflow.utils import json as utils_json, timezone, yaml +from airflow.utils import timezone, yaml from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.dag_edges import dag_edges from airflow.utils.dates import infer_time_unit, scale_time_units @@ -575,7 +576,7 @@ def health(self): 'latest_scheduler_heartbeat': latest_scheduler_heartbeat, } - return wwwutils.json_response(payload) + return flask.json.jsonify(payload) @expose('/home') @auth.has_access( @@ -856,7 +857,7 @@ def dag_stats(self, session=None): filter_dag_ids = allowed_dag_ids if not filter_dag_ids: - return wwwutils.json_response({}) + return flask.json.jsonify({}) payload = {} dag_state_stats = dag_state_stats.filter(dr.dag_id.in_(filter_dag_ids)) @@ -873,7 +874,7 @@ def dag_stats(self, session=None): count = data.get(dag_id, {}).get(state, 0) payload[dag_id].append({'state': state, 'count': count}) - return wwwutils.json_response(payload) + return flask.json.jsonify(payload) @expose('/task_stats', methods=['POST']) @auth.has_access( @@ -889,7 +890,7 @@ def task_stats(self, session=None): allowed_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) if not 
allowed_dag_ids: - return wwwutils.json_response({}) + return flask.json.jsonify({}) # Filter by post parameters selected_dag_ids = {unquote(dag_id) for dag_id in request.form.getlist('dag_ids') if dag_id} @@ -983,7 +984,7 @@ def task_stats(self, session=None): for state in State.task_states: count = data.get(dag_id, {}).get(state, 0) payload[dag_id].append({'state': state, 'count': count}) - return wwwutils.json_response(payload) + return flask.json.jsonify(payload) @expose('/last_dagruns', methods=['POST']) @auth.has_access( @@ -1006,7 +1007,7 @@ def last_dagruns(self, session=None): filter_dag_ids = allowed_dag_ids if not filter_dag_ids: - return wwwutils.json_response({}) + return flask.json.jsonify({}) last_runs_subquery = ( session.query( @@ -1046,7 +1047,7 @@ def last_dagruns(self, session=None): } for r in query } - return wwwutils.json_response(resp) + return flask.json.jsonify(resp) @expose('/code') @auth.has_access( @@ -2106,7 +2107,7 @@ def blocked(self, session=None): filter_dag_ids = allowed_dag_ids if not filter_dag_ids: - return wwwutils.json_response([]) + return flask.json.jsonify([]) dags = ( session.query(DagRun.dag_id, sqla.func.count(DagRun.id)) @@ -2129,7 +2130,7 @@ def blocked(self, session=None): 'max_active_runs': max_active_runs, } ) - return wwwutils.json_response(payload) + return flask.json.jsonify(payload) def _mark_dagrun_state_as_failed(self, dag_id, dag_run_id, confirmed): if not dag_run_id: @@ -3412,7 +3413,7 @@ def task_instances(self): for ti in dag.get_task_instances(dttm, dttm) } - return json.dumps(task_instances, cls=utils_json.AirflowJsonEncoder) + return flask.json.jsonify(task_instances) @expose('/object/grid_data') @auth.has_access( @@ -3467,7 +3468,7 @@ def grid_data(self): } # avoid spaces to reduce payload size return ( - htmlsafe_json_dumps(data, separators=(',', ':'), cls=utils_json.AirflowJsonEncoder), + htmlsafe_json_dumps(data, separators=(',', ':'), dumps=flask.json.dumps), {'Content-Type': 'application/json; 
charset=utf-8'}, ) @@ -3510,7 +3511,7 @@ def next_run_datasets(self, dag_id): .all() ] return ( - htmlsafe_json_dumps(data, separators=(',', ':'), cls=utils_json.AirflowJsonEncoder), + htmlsafe_json_dumps(data, separators=(',', ':'), dumps=flask.json.dumps), {'Content-Type': 'application/json; charset=utf-8'}, ) @@ -3547,7 +3548,7 @@ def dataset_dependencies(self): } return ( - htmlsafe_json_dumps(data, separators=(',', ':'), cls=utils_json.AirflowJsonEncoder), + htmlsafe_json_dumps(data, separators=(',', ':'), dumps=flask.json.dumps), {'Content-Type': 'application/json; charset=utf-8'}, ) @@ -5207,7 +5208,7 @@ def autocomplete(self, session=None): query = unquote(request.args.get('query', '')) if not query: - return wwwutils.json_response([]) + return flask.json.jsonify([]) # Provide suggestions of dag_ids and owners dag_ids_query = session.query( @@ -5241,7 +5242,7 @@ def autocomplete(self, session=None): payload = [ row._asdict() for row in dag_ids_query.union(owners_query).order_by('name').limit(10).all() ] - return wwwutils.json_response(payload) + return flask.json.jsonify(payload) class DagDependenciesView(AirflowBaseView): diff --git a/tests/www/test_app.py b/tests/www/test_app.py index e62bda71d07cb..d82dda1d7ae34 100644 --- a/tests/www/test_app.py +++ b/tests/www/test_app.py @@ -240,3 +240,13 @@ def test_flask_cli_should_display_routes(self, capsys): output = capsys.readouterr() assert "/login/" in output.out + + +def test_app_can_json_serialize_k8s_pod(): + # This is mostly testing that we have correctly configured the JSON provider to use. Testing the k8s pod + # is a side-effect of that. + k8s = pytest.importorskip('kubernetes.client.models') + + pod = k8s.V1Pod(spec=k8s.V1PodSpec(containers=[k8s.V1Container(name="base")])) + app = application.cached_app(testing=True) + assert app.json.dumps(pod) == '{"spec": {"containers": [{"name": "base"}]}}'