Skip to content

Commit

Permalink
Correctly set json_provider_class on Flask app so it uses our encoder (
Browse files Browse the repository at this point in the history
…#26554)

Setting `json_provider_class` where we did had no effect, as it turns
out `Flask()` sets `self.json = self.json_provider_class(self)`, so we
were setting it too late.

(cherry picked from commit 378dfbe)
  • Loading branch information
ashb authored and jedcunningham committed Sep 23, 2022
1 parent 1723b86 commit 5ae2ae5
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 29 deletions.
20 changes: 18 additions & 2 deletions airflow/utils/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
# under the License.
from __future__ import annotations

import json
import logging
from datetime import date, datetime
from decimal import Decimal

from flask.json import JSONEncoder
from flask.json.provider import JSONProvider

from airflow.utils.timezone import convert_to_utc, is_naive

Expand All @@ -40,7 +41,7 @@
log = logging.getLogger(__name__)


class AirflowJsonEncoder(JSONEncoder):
class AirflowJsonEncoder(json.JSONEncoder):
"""Custom Airflow json encoder implementation."""

def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -107,3 +108,18 @@ def safe_get_name(pod):
return {}

raise TypeError(f"Object of type '{obj.__class__.__name__}' is not JSON serializable")


class AirflowJsonProvider(JSONProvider):
"""JSON Provider for Flask app to use AirflowJsonEncoder."""

ensure_ascii: bool = True
sort_keys: bool = True

def dumps(self, obj, **kwargs):
kwargs.setdefault('ensure_ascii', self.ensure_ascii)
kwargs.setdefault('sort_keys', self.sort_keys)
return json.dumps(obj, **kwargs, cls=AirflowJsonEncoder)

def loads(self, s: str | bytes, **kwargs):
return json.loads(s, **kwargs)
5 changes: 3 additions & 2 deletions airflow/www/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from airflow.exceptions import AirflowConfigException, RemovedInAirflow3Warning
from airflow.logging_config import configure_logging
from airflow.models import import_all_models
from airflow.utils.json import AirflowJsonEncoder
from airflow.utils.json import AirflowJsonProvider
from airflow.www.extensions.init_appbuilder import init_appbuilder
from airflow.www.extensions.init_appbuilder_links import init_appbuilder_links
from airflow.www.extensions.init_dagbag import init_dagbag
Expand Down Expand Up @@ -109,7 +109,8 @@ def create_app(config=None, testing=False):
flask_app.config['SQLALCHEMY_ENGINE_OPTIONS'] = settings.prepare_engine_args()

# Configure the JSON encoder used by `|tojson` filter from Flask
flask_app.json_provider_class = AirflowJsonEncoder
flask_app.json_provider_class = AirflowJsonProvider
flask_app.json = AirflowJsonProvider(flask_app)

csrf.init_app(flask_app)

Expand Down
10 changes: 1 addition & 9 deletions airflow/www/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from urllib.parse import urlencode

import sqlalchemy as sqla
from flask import Response, request, url_for
from flask import request, url_for
from flask.helpers import flash
from flask_appbuilder.forms import FieldConverter
from flask_appbuilder.models.filters import BaseFilter
Expand All @@ -47,7 +47,6 @@
from airflow.utils import timezone
from airflow.utils.code_utils import get_python_source
from airflow.utils.helpers import alchemy_to_dict
from airflow.utils.json import AirflowJsonEncoder
from airflow.utils.state import State, TaskInstanceState
from airflow.www.forms import DateTimeWithTimezoneField
from airflow.www.widgets import AirflowDateTimePickerWidget
Expand Down Expand Up @@ -322,13 +321,6 @@ def epoch(dttm):
return (int(time.mktime(dttm.timetuple())) * 1000,)


def json_response(obj):
"""Returns a json response from a json serializable python object"""
return Response(
response=json.dumps(obj, indent=4, cls=AirflowJsonEncoder), status=200, mimetype="application/json"
)


def make_cache_key(*args, **kwargs):
"""Used by cache to get a unique key per URL"""
path = request.path
Expand Down
33 changes: 17 additions & 16 deletions airflow/www/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from urllib.parse import parse_qsl, unquote, urlencode, urlparse

import configupdater
import flask.json
import lazy_object_proxy
import markupsafe
import nvd3
Expand Down Expand Up @@ -107,7 +108,7 @@
from airflow.ti_deps.dependencies_deps import RUNNING_DEPS, SCHEDULER_QUEUED_DEPS
from airflow.timetables.base import DataInterval, TimeRestriction
from airflow.timetables.interval import CronDataIntervalTimetable
from airflow.utils import json as utils_json, timezone, yaml
from airflow.utils import timezone, yaml
from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.utils.dag_edges import dag_edges
from airflow.utils.dates import infer_time_unit, scale_time_units
Expand Down Expand Up @@ -575,7 +576,7 @@ def health(self):
'latest_scheduler_heartbeat': latest_scheduler_heartbeat,
}

return wwwutils.json_response(payload)
return flask.json.jsonify(payload)

@expose('/home')
@auth.has_access(
Expand Down Expand Up @@ -856,7 +857,7 @@ def dag_stats(self, session=None):
filter_dag_ids = allowed_dag_ids

if not filter_dag_ids:
return wwwutils.json_response({})
return flask.json.jsonify({})

payload = {}
dag_state_stats = dag_state_stats.filter(dr.dag_id.in_(filter_dag_ids))
Expand All @@ -873,7 +874,7 @@ def dag_stats(self, session=None):
count = data.get(dag_id, {}).get(state, 0)
payload[dag_id].append({'state': state, 'count': count})

return wwwutils.json_response(payload)
return flask.json.jsonify(payload)

@expose('/task_stats', methods=['POST'])
@auth.has_access(
Expand All @@ -889,7 +890,7 @@ def task_stats(self, session=None):
allowed_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user)

if not allowed_dag_ids:
return wwwutils.json_response({})
return flask.json.jsonify({})

# Filter by post parameters
selected_dag_ids = {unquote(dag_id) for dag_id in request.form.getlist('dag_ids') if dag_id}
Expand Down Expand Up @@ -983,7 +984,7 @@ def task_stats(self, session=None):
for state in State.task_states:
count = data.get(dag_id, {}).get(state, 0)
payload[dag_id].append({'state': state, 'count': count})
return wwwutils.json_response(payload)
return flask.json.jsonify(payload)

@expose('/last_dagruns', methods=['POST'])
@auth.has_access(
Expand All @@ -1006,7 +1007,7 @@ def last_dagruns(self, session=None):
filter_dag_ids = allowed_dag_ids

if not filter_dag_ids:
return wwwutils.json_response({})
return flask.json.jsonify({})

last_runs_subquery = (
session.query(
Expand Down Expand Up @@ -1046,7 +1047,7 @@ def last_dagruns(self, session=None):
}
for r in query
}
return wwwutils.json_response(resp)
return flask.json.jsonify(resp)

@expose('/code')
@auth.has_access(
Expand Down Expand Up @@ -2104,7 +2105,7 @@ def blocked(self, session=None):
filter_dag_ids = allowed_dag_ids

if not filter_dag_ids:
return wwwutils.json_response([])
return flask.json.jsonify([])

dags = (
session.query(DagRun.dag_id, sqla.func.count(DagRun.id))
Expand All @@ -2127,7 +2128,7 @@ def blocked(self, session=None):
'max_active_runs': max_active_runs,
}
)
return wwwutils.json_response(payload)
return flask.json.jsonify(payload)

def _mark_dagrun_state_as_failed(self, dag_id, dag_run_id, confirmed):
if not dag_run_id:
Expand Down Expand Up @@ -3410,7 +3411,7 @@ def task_instances(self):
for ti in dag.get_task_instances(dttm, dttm)
}

return json.dumps(task_instances, cls=utils_json.AirflowJsonEncoder)
return flask.json.jsonify(task_instances)

@expose('/object/grid_data')
@auth.has_access(
Expand Down Expand Up @@ -3465,7 +3466,7 @@ def grid_data(self):
}
# avoid spaces to reduce payload size
return (
htmlsafe_json_dumps(data, separators=(',', ':'), cls=utils_json.AirflowJsonEncoder),
htmlsafe_json_dumps(data, separators=(',', ':'), dumps=flask.json.dumps),
{'Content-Type': 'application/json; charset=utf-8'},
)

Expand Down Expand Up @@ -3508,7 +3509,7 @@ def next_run_datasets(self, dag_id):
.all()
]
return (
htmlsafe_json_dumps(data, separators=(',', ':'), cls=utils_json.AirflowJsonEncoder),
htmlsafe_json_dumps(data, separators=(',', ':'), dumps=flask.json.dumps),
{'Content-Type': 'application/json; charset=utf-8'},
)

Expand Down Expand Up @@ -3545,7 +3546,7 @@ def dataset_dependencies(self):
}

return (
htmlsafe_json_dumps(data, separators=(',', ':'), cls=utils_json.AirflowJsonEncoder),
htmlsafe_json_dumps(data, separators=(',', ':'), dumps=flask.json.dumps),
{'Content-Type': 'application/json; charset=utf-8'},
)

Expand Down Expand Up @@ -5205,7 +5206,7 @@ def autocomplete(self, session=None):
query = unquote(request.args.get('query', ''))

if not query:
return wwwutils.json_response([])
return flask.json.jsonify([])

# Provide suggestions of dag_ids and owners
dag_ids_query = session.query(
Expand Down Expand Up @@ -5239,7 +5240,7 @@ def autocomplete(self, session=None):
payload = [
row._asdict() for row in dag_ids_query.union(owners_query).order_by('name').limit(10).all()
]
return wwwutils.json_response(payload)
return flask.json.jsonify(payload)


class DagDependenciesView(AirflowBaseView):
Expand Down
10 changes: 10 additions & 0 deletions tests/www/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,3 +240,13 @@ def test_flask_cli_should_display_routes(self, capsys):

output = capsys.readouterr()
assert "/login/" in output.out


def test_app_can_json_serialize_k8s_pod():
# This is mostly testing that we have correctly configured the JSON provider to use. Testing the k8s pos
# is a side-effect of that.
k8s = pytest.importorskip('kubernetes.client.models')

pod = k8s.V1Pod(spec=k8s.V1PodSpec(containers=[k8s.V1Container(name="base")]))
app = application.cached_app(testing=True)
assert app.json.dumps(pod) == '{"spec": {"containers": [{"name": "base"}]}}'

0 comments on commit 5ae2ae5

Please sign in to comment.