Skip to content

Commit

Permalink
Merge pull request #1229 from tim-schilling/postgres-json-explain
Browse files Browse the repository at this point in the history
Support SQL Select and Explain actions for Postgres JSON fields.
  • Loading branch information
matthiask authored Jan 31, 2020
2 parents 98308a2 + ab5c53e commit 852e455
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 1 deletion.
8 changes: 7 additions & 1 deletion debug_toolbar/panels/sql/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
from debug_toolbar import settings as dt_settings
from debug_toolbar.utils import get_stack, get_template_info, tidy_stacktrace

try:
from psycopg2._json import Json as PostgresJson
except ImportError:
PostgresJson = None


class SQLQueryTriggered(Exception):
"""Thrown when template panel triggers a query"""
Expand Down Expand Up @@ -105,6 +110,8 @@ def _quote_params(self, params):
return [self._quote_expr(p) for p in params]

def _decode(self, param):
if PostgresJson and isinstance(param, PostgresJson):
return param.dumps(param.adapted)
# If a sequence type, decode each element separately
if isinstance(param, (tuple, list)):
return [self._decode(element) for element in param]
Expand Down Expand Up @@ -136,7 +143,6 @@ def _record(self, method, sql, params):
_params = json.dumps(self._decode(params))
except TypeError:
pass # object not JSON serializable

template_info = get_template_info()

alias = getattr(self.db, "alias", "default")
Expand Down
11 changes: 11 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,14 @@ def __repr__(self):

class Binary(models.Model):
field = models.BinaryField()


try:
from django.contrib.postgres.fields import JSONField

class PostgresJSON(models.Model):
field = JSONField()


except ImportError:
pass
30 changes: 30 additions & 0 deletions tests/panels/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,16 @@

from ..base import BaseTestCase

try:
from psycopg2._json import Json as PostgresJson
except ImportError:
PostgresJson = None

if connection.vendor == "postgresql":
from ..models import PostgresJSON as PostgresJSONModel
else:
PostgresJSONModel = None


class SQLPanelTestCase(BaseTestCase):
panel_id = "SQLPanel"
Expand Down Expand Up @@ -120,6 +130,26 @@ def test_param_conversion(self):
('["Foo", true, false]', "[10, 1]", '["2017-12-22 16:07:01"]'),
)

@unittest.skipUnless(
connection.vendor == "postgresql", "Test valid only on PostgreSQL"
)
def test_json_param_conversion(self):
self.assertEqual(len(self.panel._queries), 0)

list(PostgresJSONModel.objects.filter(field__contains={"foo": "bar"}))

response = self.panel.process_request(self.request)
self.panel.generate_stats(self.request, response)

# ensure query was logged
self.assertEqual(len(self.panel._queries), 1)
self.assertEqual(
self.panel._queries[0][1]["params"], '["{\\"foo\\": \\"bar\\"}"]',
)
self.assertIsInstance(
self.panel._queries[0][1]["raw_params"][0], PostgresJson,
)

def test_binary_param_force_text(self):
self.assertEqual(len(self.panel._queries), 0)

Expand Down
30 changes: 30 additions & 0 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from django.contrib.staticfiles.testing import StaticLiveServerTestCase
from django.core import signing
from django.core.checks import Warning, run_checks
from django.db import connection
from django.http import HttpResponse
from django.template.loader import get_template
from django.test import RequestFactory, SimpleTestCase, TestCase
Expand Down Expand Up @@ -206,6 +207,35 @@ def test_sql_explain_checks_show_toolbar(self):
)
self.assertEqual(response.status_code, 404)

@unittest.skipUnless(
connection.vendor == "postgresql", "Test valid only on PostgreSQL"
)
def test_sql_explain_postgres_json_field(self):
url = "/__debug__/sql_explain/"
base_query = (
'SELECT * FROM "tests_postgresjson" WHERE "tests_postgresjson"."field" @>'
)
query = base_query + """ '{"foo": "bar"}'"""
data = {
"sql": query,
"raw_sql": base_query + " %s",
"params": '["{\\"foo\\": \\"bar\\"}"]',
"alias": "default",
"duration": "0",
"hash": "2b7172eb2ac8e2a8d6f742f8a28342046e0d00ba",
}
response = self.client.post(url, data)
self.assertEqual(response.status_code, 200)
response = self.client.post(url, data, HTTP_X_REQUESTED_WITH="XMLHttpRequest")
self.assertEqual(response.status_code, 200)
with self.settings(INTERNAL_IPS=[]):
response = self.client.post(url, data)
self.assertEqual(response.status_code, 404)
response = self.client.post(
url, data, HTTP_X_REQUESTED_WITH="XMLHttpRequest"
)
self.assertEqual(response.status_code, 404)

def test_sql_profile_checks_show_toolbar(self):
url = "/__debug__/sql_profile/"
data = {
Expand Down

0 comments on commit 852e455

Please sign in to comment.