Skip to content

Commit

Permalink
[sqllab] templating refactor (#1504)
Browse files Browse the repository at this point in the history
* Add support for jinja templates in WHERE/HAVING clauses

* Generalizing

* bugfix
  • Loading branch information
mistercrunch authored Nov 2, 2016
1 parent 0bab15b commit 1700a80
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 51 deletions.
74 changes: 34 additions & 40 deletions caravel/jinja_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,18 @@
from caravel.utils import CaravelTemplateException

config = app.config
BASE_CONTEXT = {
'datetime': datetime,
'random': random,
'relativedelta': relativedelta,
'time': time,
'timedelta': timedelta,
'uuid': uuid,
}
BASE_CONTEXT.update(config.get('JINJA_CONTEXT_ADDONS', {}))


class BaseContext(object):
class BaseTemplateProcessor(object):

"""Base class for database-specific jinja context
Expand All @@ -37,17 +46,31 @@ class BaseContext(object):
"""
engine = None

def __init__(self, database, query):
def __init__(self, database=None, query=None, table=None):
self.database = database
self.query = query
self.schema = None
if query and query.schema:
self.schema = query.schema
elif database:
self.schema = database.schema
elif table:
self.schema = table.schema
self.context = {}
self.context.update(BASE_CONTEXT)
if self.engine:
self.context[self.engine] = self

def process_template(self, sql):
"""Processes a sql template
>>> sql = "SELECT '{{ datetime(2017, 1, 1).isoformat() }}'"
>>> process_template(sql)
"SELECT '2017-01-01T00:00:00'"
"""
template = jinja2.Template(sql)
return template.render(self.context)


class PrestoContext(BaseContext):
class PrestoTemplateProcessor(BaseTemplateProcessor):
"""Presto Jinja context
The methods described here are namespaced under ``presto`` in the
Expand Down Expand Up @@ -170,43 +193,14 @@ def latest_sub_partition(self, table_name, **kwargs):
return df.to_dict()[field_to_return][0]


db_contexes = {}
template_processors = {}
keys = tuple(globals().keys())
for k in keys:
o = globals()[k]
if o and inspect.isclass(o) and issubclass(o, BaseContext):
db_contexes[o.engine] = o


def get_context(engine_name=None):
context = {
'datetime': datetime,
'random': random,
'relativedelta': relativedelta,
'time': time,
'timedelta': timedelta,
'uuid': uuid,
}
db_context = db_contexes.get(engine_name)
if engine_name and db_context:
context[engine_name] = db_context
return context


def process_template(sql, database=None, query=None):
"""Processes a sql template
>>> sql = "SELECT '{{ datetime(2017, 1, 1).isoformat() }}'"
>>> process_template(sql)
"SELECT '2017-01-01T00:00:00'"
"""
if o and inspect.isclass(o) and issubclass(o, BaseTemplateProcessor):
template_processors[o.engine] = o

context = get_context(database.backend if database else None)
template = jinja2.Template(sql)
backend = database.backend if database else None

# instantiating only the context for the active database
if context and backend in context:
context[backend] = context[backend](database, query)
context.update(config.get('JINJA_CONTEXT_ADDONS', {}))
return template.render(context)
def get_template_processor(database, table=None, query=None):
TP = template_processors.get(database.backend, BaseTemplateProcessor)
return TP(database=database, table=table, query=query)
20 changes: 13 additions & 7 deletions caravel/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
from caravel import app, db, db_engine_specs, get_session, utils, sm
from caravel.source_registry import SourceRegistry
from caravel.viz import viz_types
from caravel.jinja_context import process_template
from caravel.jinja_context import get_template_processor
from caravel.utils import (
flasher, MetricPermException, DimSelector, wrap_clause_in_parens
)
Expand Down Expand Up @@ -960,6 +960,9 @@ def query( # sqla
extras=None,
columns=None):
"""Querying any sqla table from this common interface"""
template_processor = get_template_processor(
table=self, database=self.database)

# For backward compatibility
if granularity not in self.dttm_cols:
granularity = self.main_dttm_col
Expand Down Expand Up @@ -1088,12 +1091,15 @@ def visit_column(element, compiler, **kw):
if op == 'not in':
cond = ~cond
where_clause_and.append(cond)
if extras and 'where' in extras:
where = wrap_clause_in_parens(process_template(extras['where'], self.database))
where_clause_and += [where]
if extras and 'having' in extras:
having = wrap_clause_in_parens(process_template(extras['having'], self.database))
having_clause_and += [having]
if extras:
where = extras.get('where')
if where:
where_clause_and += [wrap_clause_in_parens(
template_processor.process_template(where))]
having = extras.get('having')
if having:
having_clause_and += [wrap_clause_in_parens(
template_processor.process_template(having))]
if granularity:
qry = qry.where(and_(*(time_filter + where_clause_and)))
else:
Expand Down
7 changes: 4 additions & 3 deletions caravel/sql_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
from caravel import (
app, db, models, utils, dataframe, results_backend)
from caravel.db_engine_specs import LimitMethod
from caravel.jinja_context import process_template

from caravel.jinja_context import get_template_processor
QueryStatus = models.QueryStatus

celery_app = celery.Celery(config_source=app.config.get('CELERY_CONFIG'))
Expand Down Expand Up @@ -101,7 +100,9 @@ def handle_error(msg):
query.limit_used = True
engine = database.get_sqla_engine(schema=query.schema)
try:
executed_sql = process_template(executed_sql, database, query)
template_processor = get_template_processor(
database=database, query=query)
executed_sql = template_processor.process_template(executed_sql)
except Exception as e:
logging.exception(e)
msg = "Template rendering failed: " + utils.error_msg_from_exception(e)
Expand Down
1 change: 1 addition & 0 deletions dev-reqs.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
codeclimate-test-reporter
coveralls
flake8
mock
mysqlclient
nose
Expand Down
4 changes: 3 additions & 1 deletion tests/core_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,8 +358,10 @@ def test_extra_table_metadata(self):
'ab_permission_view/panoramix/'.format(**locals()))

def test_process_template(self):
maindb = self.get_main_database(db.session)
sql = "SELECT '{{ datetime(2017, 1, 1).isoformat() }}'"
rendered = jinja_context.process_template(sql)
tp = jinja_context.get_template_processor(database=maindb)
rendered = tp.process_template(sql)
self.assertEqual("SELECT '2017-01-01T00:00:00'", rendered)

def test_templated_sql_json(self):
Expand Down

0 comments on commit 1700a80

Please sign in to comment.