From ab5c53e4327b35923fa0e427b4225057ab4746ac Mon Sep 17 00:00:00 2001 From: Tim Schilling Date: Fri, 20 Dec 2019 16:03:21 -0600 Subject: [PATCH] Support SQL Select and Explain actions for Postgres JSON fields. --- debug_toolbar/panels/sql/tracking.py | 8 +++++++- tests/models.py | 11 ++++++++++ tests/panels/test_sql.py | 30 ++++++++++++++++++++++++++++ tests/test_integration.py | 30 ++++++++++++++++++++++++++++ 4 files changed, 78 insertions(+), 1 deletion(-) diff --git a/debug_toolbar/panels/sql/tracking.py b/debug_toolbar/panels/sql/tracking.py index c16f2319f..75366802c 100644 --- a/debug_toolbar/panels/sql/tracking.py +++ b/debug_toolbar/panels/sql/tracking.py @@ -8,6 +8,11 @@ from debug_toolbar import settings as dt_settings from debug_toolbar.utils import get_stack, get_template_info, tidy_stacktrace +try: + from psycopg2._json import Json as PostgresJson +except ImportError: + PostgresJson = None + class SQLQueryTriggered(Exception): """Thrown when template panel triggers a query""" @@ -105,6 +110,8 @@ def _quote_params(self, params): return [self._quote_expr(p) for p in params] def _decode(self, param): + if PostgresJson and isinstance(param, PostgresJson): + return param.dumps(param.adapted) # If a sequence type, decode each element separately if isinstance(param, (tuple, list)): return [self._decode(element) for element in param] @@ -136,7 +143,6 @@ def _record(self, method, sql, params): _params = json.dumps(self._decode(params)) except TypeError: pass # object not JSON serializable - template_info = get_template_info() alias = getattr(self.db, "alias", "default") diff --git a/tests/models.py b/tests/models.py index 652bed98a..8904155a9 100644 --- a/tests/models.py +++ b/tests/models.py @@ -8,3 +8,14 @@ def __repr__(self): class Binary(models.Model): field = models.BinaryField() + + +try: + from django.contrib.postgres.fields import JSONField + + class PostgresJSON(models.Model): + field = JSONField() + + +except ImportError: + pass diff --git a/tests/panels/test_sql.py b/tests/panels/test_sql.py index b69099963..e1fd8ef42 100644 --- a/tests/panels/test_sql.py +++ b/tests/panels/test_sql.py @@ -11,6 +11,16 @@ from ..base import BaseTestCase +try: + from psycopg2._json import Json as PostgresJson +except ImportError: + PostgresJson = None + +if connection.vendor == "postgresql": + from ..models import PostgresJSON as PostgresJSONModel +else: + PostgresJSONModel = None + class SQLPanelTestCase(BaseTestCase): panel_id = "SQLPanel" @@ -120,6 +130,26 @@ def test_param_conversion(self): ('["Foo", true, false]', "[10, 1]", '["2017-12-22 16:07:01"]'), ) + @unittest.skipUnless( + connection.vendor == "postgresql", "Test valid only on PostgreSQL" + ) + def test_json_param_conversion(self): + self.assertEqual(len(self.panel._queries), 0) + + list(PostgresJSONModel.objects.filter(field__contains={"foo": "bar"})) + + response = self.panel.process_request(self.request) + self.panel.generate_stats(self.request, response) + + # ensure query was logged + self.assertEqual(len(self.panel._queries), 1) + self.assertEqual( + self.panel._queries[0][1]["params"], '["{\\"foo\\": \\"bar\\"}"]', + ) + self.assertIsInstance( + self.panel._queries[0][1]["raw_params"][0], PostgresJson, + ) + def test_binary_param_force_text(self): self.assertEqual(len(self.panel._queries), 0) diff --git a/tests/test_integration.py b/tests/test_integration.py index 0adbdb03c..94e5ac990 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -6,6 +6,7 @@ from django.contrib.staticfiles.testing import StaticLiveServerTestCase from django.core import signing from django.core.checks import Warning, run_checks +from django.db import connection from django.http import HttpResponse from django.template.loader import get_template from django.test import RequestFactory, SimpleTestCase, TestCase @@ -206,6 +207,35 @@ def test_sql_explain_checks_show_toolbar(self): ) self.assertEqual(response.status_code, 404) + @unittest.skipUnless( + connection.vendor == "postgresql", "Test valid only on PostgreSQL" + ) + def test_sql_explain_postgres_json_field(self): + url = "/__debug__/sql_explain/" + base_query = ( + 'SELECT * FROM "tests_postgresjson" WHERE "tests_postgresjson"."field" @>' + ) + query = base_query + """ '{"foo": "bar"}'""" + data = { + "sql": query, + "raw_sql": base_query + " %s", + "params": '["{\\"foo\\": \\"bar\\"}"]', + "alias": "default", + "duration": "0", + "hash": "2b7172eb2ac8e2a8d6f742f8a28342046e0d00ba", + } + response = self.client.post(url, data) + self.assertEqual(response.status_code, 200) + response = self.client.post(url, data, HTTP_X_REQUESTED_WITH="XMLHttpRequest") + self.assertEqual(response.status_code, 200) + with self.settings(INTERNAL_IPS=[]): + response = self.client.post(url, data) + self.assertEqual(response.status_code, 404) + response = self.client.post( + url, data, HTTP_X_REQUESTED_WITH="XMLHttpRequest" + ) + self.assertEqual(response.status_code, 404) + def test_sql_profile_checks_show_toolbar(self): url = "/__debug__/sql_profile/" data = {