From 0ca3f5ec800b5543747a00f24377a86a6f017226 Mon Sep 17 00:00:00 2001 From: Maxime Beauchemin Date: Mon, 2 May 2016 10:00:28 -0700 Subject: [PATCH] Improving SQLA query generation (#421) * Improving SQLA query generation * Fixing debug --- caravel/data/__init__.py | 2 +- ...9e13_adding_verbose_name_to_tablecolumn.py | 25 +++++++++ caravel/models.py | 51 +++++++++++-------- caravel/views.py | 4 +- tests/core_tests.py | 7 +-- 5 files changed, 62 insertions(+), 27 deletions(-) create mode 100644 caravel/migrations/versions/f0fbf6129e13_adding_verbose_name_to_tablecolumn.py diff --git a/caravel/data/__init__.py b/caravel/data/__init__.py index 7b68670287f8a..69ad02cb3d2f3 100644 --- a/caravel/data/__init__.py +++ b/caravel/data/__init__.py @@ -372,7 +372,7 @@ def load_world_bank_health_n_pop(): merge_slice(slc) print("Creating a World's Health Bank dashboard") - dash_name = "World's Health Bank Dashboard" + dash_name = "World's Bank Data" dash = db.session.query(Dash).filter_by(dashboard_title=dash_name).first() if not dash: diff --git a/caravel/migrations/versions/f0fbf6129e13_adding_verbose_name_to_tablecolumn.py b/caravel/migrations/versions/f0fbf6129e13_adding_verbose_name_to_tablecolumn.py new file mode 100644 index 0000000000000..51f4923b9cedf --- /dev/null +++ b/caravel/migrations/versions/f0fbf6129e13_adding_verbose_name_to_tablecolumn.py @@ -0,0 +1,25 @@ +"""Adding verbose_name to tablecolumn + +Revision ID: f0fbf6129e13 +Revises: c3a8f8611885 +Create Date: 2016-05-01 12:21:18.331191 + +""" + +# revision identifiers, used by Alembic. +revision = 'f0fbf6129e13' +down_revision = 'c3a8f8611885' + +from alembic import op +import sqlalchemy as sa + + +def upgrade(): + op.add_column( + 'table_columns', + sa.Column('verbose_name', sa.String(length=1024), + nullable=True)) + + +def downgrade(): + op.drop_column('table_columns', 'verbose_name') diff --git a/caravel/models.py b/caravel/models.py index 619682a12be4a..928f1075a706a 100644 --- a/caravel/models.py +++ b/caravel/models.py @@ -564,15 +564,15 @@ def query( # sqla "and is required by this type of chart") metrics_exprs = [ - literal_column(m.expression).label(m.metric_name) + m.sqla_col for m in self.metrics if m.metric_name in metrics] if metrics: - main_metric_expr = literal_column([ - m.expression for m in self.metrics - if m.metric_name == metrics[0]][0]) + main_metric_expr = [ + m.sqla_col for m in self.metrics + if m.metric_name == metrics[0]][0] else: - main_metric_expr = literal_column("COUNT(*)") + main_metric_expr = literal_column("COUNT(*)").label("ccount") select_exprs = [] groupby_exprs = [] @@ -583,13 +583,8 @@ def query( # sqla inner_groupby_exprs = [] for s in groupby: col = cols[s] - expr = col.expression - if expr: - outer = literal_column(expr).label(s) - inner = literal_column(expr).label('__' + s) - else: - outer = column(s).label(s) - inner = column(s).label('__' + s) + outer = col.sqla_col + inner = col.sqla_col.label('__' + col.column_name) groupby_exprs.append(outer) select_exprs.append(outer) @@ -597,12 +592,12 @@ def query( # sqla inner_select_exprs.append(inner) elif columns: for s in columns: - select_exprs.append(s) + select_exprs.append(cols[s].sqla_col) metrics_exprs = [] if granularity: - dttm_expr = cols[granularity].expression or granularity - timestamp = literal_column(dttm_expr).label('timestamp') + dttm_expr = cols[granularity].sqla_col.label('timestamp') + timestamp = dttm_expr # Transforming time grain into an expression based on configuration time_grain_sqla = extras.get('time_grain_sqla') @@ -646,11 +641,7 @@ def query( # sqla col_obj = cols[col] if op in ('in', 'not in'): values = eq.split(",") - if col_obj.expression: - cond = ColumnClause( - col_obj.expression, is_literal=True).in_(values) - else: - cond = column(col).in_(values) + cond = col_obj.sqla_col.in_(values) if op == 'not in': cond = ~cond where_clause_and.append(cond) @@ -685,7 +676,10 @@ def query( # sqla engine = self.database.get_sqla_engine() sql = "{}".format( - qry.compile(engine, compile_kwargs={"literal_binds": True})) + qry.compile( + engine, compile_kwargs={"literal_binds": True},), + ) + print(sql) df = pd.read_sql_query( sql=sql, con=engine @@ -811,6 +805,11 @@ class SqlMetric(Model, AuditMixinNullable): expression = Column(Text) description = Column(Text) + @property + def sqla_col(self): + name = self.metric_name + return literal_column(self.expression).label(name) + class TableColumn(Model, AuditMixinNullable): @@ -822,6 +821,7 @@ class TableColumn(Model, AuditMixinNullable): table = relationship( 'SqlaTable', backref='columns', foreign_keys=[table_id]) column_name = Column(String(256)) + verbose_name = Column(String(1024)) is_dttm = Column(Boolean, default=False) is_active = Column(Boolean, default=True) type = Column(String(32), default='') @@ -842,6 +842,15 @@ def isnum(self): types = ('LONG', 'DOUBLE', 'FLOAT', 'BIGINT', 'INT') return any([t in self.type.upper() for t in types]) + @property + def sqla_col(self): + name = self.column_name + if not self.expression: + col = column(self.column_name).label(name) + else: + col = literal_column(self.expression).label(name) + return col + class DruidCluster(Model, AuditMixinNullable): diff --git a/caravel/views.py b/caravel/views.py index 6d6ef5d7d7d7e..69ee8e918adfd 100644 --- a/caravel/views.py +++ b/caravel/views.py @@ -102,8 +102,8 @@ class TableColumnInlineView(CompactCRUDMixin, CaravelModelView): # noqa datamodel = SQLAInterface(models.TableColumn) can_delete = False edit_columns = [ - 'column_name', 'description', 'groupby', 'filterable', 'table', - 'count_distinct', 'sum', 'min', 'max', 'expression', 'is_dttm'] + 'column_name', 'verbose_name', 'description', 'groupby', 'filterable', + 'table', 'count_distinct', 'sum', 'min', 'max', 'expression', 'is_dttm'] add_columns = edit_columns list_columns = [ 'column_name', 'type', 'groupby', 'filterable', 'count_distinct', diff --git a/tests/core_tests.py b/tests/core_tests.py index b25c85799cb77..d8e527808e57b 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -113,10 +113,11 @@ def test_slices(self): urls = [] for slc in db.session.query(Slc).all(): urls += [ - slc.slice_url, - slc.viz.json_endpoint, + (slc.slice_name, slc.slice_url), + (slc.slice_name, slc.viz.json_endpoint), ] - for url in urls: + for name, url in urls: + print("Slice: " + name) self.client.get(url) def test_dashboard(self):