Commit 0a24cb6

Support authentication against remote APIs for PandasConnector - see
rhunwicks committed Dec 4, 2017
1 parent c892e50 commit 0a24cb6
Showing 8 changed files with 313 additions and 95 deletions.
2 changes: 1 addition & 1 deletion contrib/cache/dataframe.py
@@ -129,7 +129,7 @@ def set(self, key, value, timeout=None):
value.to_hdf(tmp, 'df')
metadata['format'] = 'hdf'
metadata['read_args'] = {'key': 'df'}
except ImportError:
except Exception:
# PyTables is not installed or the HDF write failed, so fall back to pickle
pickle.dump(value, f, pickle.HIGHEST_PROTOCOL)
metadata['format'] = 'pickle'
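To illustrate the broadened fallback, a self-contained sketch (file names here are invented; the real set() method writes into the cache's own file handles):

import pickle

import pandas as pd

value = pd.DataFrame({'x': [1, 2, 3]})
metadata = {}
try:
    # Raises ImportError when PyTables is absent, but can also raise
    # TypeError and friends for frames that HDF5 cannot represent
    value.to_hdf('cache.h5', key='df')
    metadata['format'] = 'hdf'
    metadata['read_args'] = {'key': 'df'}
except Exception:
    # Any failure now falls back to pickle
    with open('cache.pkl', 'wb') as f:
        pickle.dump(value, f, pickle.HIGHEST_PROTOCOL)
    metadata['format'] = 'pickle'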
192 changes: 124 additions & 68 deletions contrib/connectors/pandas/models.py
@@ -7,7 +7,9 @@
from datetime import datetime
import hashlib
import logging
from io import BytesIO
from past.builtins import basestring
import requests
try:
from urllib.parse import urlparse
except ImportError:
@@ -18,7 +20,7 @@
is_string_dtype, is_numeric_dtype, is_datetime64_any_dtype)

from sqlalchemy import (
Column, Integer, String, ForeignKey, Text, or_
Column, Integer, String, ForeignKey, Text, and_, or_
)
import sqlalchemy as sa
from sqlalchemy.orm import backref, relationship
@@ -105,6 +107,14 @@ def data(self):
'filterable', 'groupby')
return {s: getattr(self, s) for s in attrs}

def get_perm(self):
if self.datasource:
return ('{parent_name}.[{obj.expression}]'
'(id:{obj.id})').format(
obj=self,
parent_name=self.datasource.full_name)
return None
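For context, the permission string produced above looks like this (datasource name, expression and id are invented):

# Invented values: datasource full_name 'pandas.example_report',
# column expression 'amount', column id 7
perm = ('{parent_name}.[{expression}]'
        '(id:{id})').format(parent_name='pandas.example_report',
                            expression='amount', id=7)
print(perm)  # pandas.example_report.[amount](id:7)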


class PandasMetric(Model, BaseMetric):
"""
@@ -124,8 +134,7 @@ class PandasMetric(Model, BaseMetric):
source = Column(Text)
expression = Column(Text)

@property
def perm(self):
def get_perm(self):
if self.datasource:
return ('{parent_name}.[{obj.metric_name}]'
'(id:{obj.id})').format(
@@ -165,6 +174,8 @@ class PandasDatasource(Model, BaseDatasource):

name = Column(String(100), nullable=False)
source_url = Column(String(1000), nullable=False)
source_auth = Column(JSONType)
source_parameters = Column(JSONType)
format = Column(ChoiceType(FORMATS), nullable=False)
additional_parameters = Column(JSONType)
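The shapes of the two new JSON columns follow from how get_dataframe() consumes them below; a hypothetical configuration (all values invented, and 'csv' assumed to be one of the FORMATS choices):

datasource = PandasDatasource(
    name='example_report',
    source_url='https://api.example.com/v1/report.csv',
    source_auth=['alice', 's3cr3t'],    # list/tuple -> HTTP basic auth
    # source_auth='Bearer abc123',      # plain string -> Authorization header
    source_parameters={'year': 2017},   # sent as query-string parameters
    format='csv',
)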

@@ -296,7 +307,9 @@ def get_empty_dataframe(self):

@property
def cache_key(self):
source = {'source_url': self.source_url}
source = {'source_url': self.source_url,
'source_auth': self.source_auth}
source.update(self.source_parameters or {})
source.update(self.pandas_read_parameters)
s = str([(k, source[k]) for k in sorted(source.keys())])
return hashlib.md5(s.encode('utf-8')).hexdigest()
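A stand-alone reproduction of the key derivation with invented values; because source_auth and source_parameters now feed the key, changing credentials or query parameters invalidates any cached frame:

import hashlib

source = {'source_url': 'https://api.example.com/v1/report.csv',
          'source_auth': ['alice', 's3cr3t']}
source.update({'year': 2017})   # source_parameters
source.update({'sep': ','})     # pandas_read_parameters
s = str([(k, source[k]) for k in sorted(source.keys())])
print(hashlib.md5(s.encode('utf-8')).hexdigest())  # deterministic cache key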
@@ -313,7 +326,25 @@ def get_dataframe(self):
cache_key = self.cache_key
self.df = dataframe_cache.get(cache_key)
if not isinstance(self.df, pd.DataFrame):
self.df = self.pandas_read_method(self.source_url, **self.pandas_read_parameters)
if isinstance(self.source_url, basestring) and self.source_url[:4] == 'http':
# Use requests to retrieve remote data so we can handle authentication
auth = self.source_auth
url = self.source_url
if isinstance(auth, (tuple, list)):
response = requests.get(url, params=self.source_parameters,
auth=tuple(auth))
elif auth:
response = requests.get(url, params=self.source_parameters,
headers={'Authorization': auth})
else:
response = requests.get(url, params=self.source_parameters)
response.raise_for_status()
data = BytesIO(response.content)
else:
# Local file, so just use Pandas directly
data = self.source_url
# Read the dataframe from the response or local file
self.df = self.pandas_read_method(data, **self.pandas_read_parameters)

# read_html returns a list of DataFrames
if (isinstance(self.df, list) and
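A rough equivalent of the remote branch above, with a hypothetical endpoint and token (list or tuple credentials become HTTP basic auth; any other truthy source_auth is passed through as a raw Authorization header):

from io import BytesIO

import pandas as pd
import requests

url = 'https://api.example.com/v1/report.csv'   # hypothetical endpoint
response = requests.get(url, params={'year': 2017},
                        headers={'Authorization': 'Bearer abc123'})
response.raise_for_status()   # surface HTTP errors before parsing
df = pd.read_csv(BytesIO(response.content))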
@@ -763,7 +794,6 @@ def get_metadata(self):
"""Build the metadata for the table and merge it in"""
df = self.get_dataframe()

metrics = []
any_date_col = None
dbcols = (
db.session.query(PandasColumn)
@@ -773,10 +803,20 @@
dbcols = {dbcol.column_name: dbcol for dbcol in dbcols}
for col in df.columns:
dbcol = dbcols.get(col, None)

if not dbcol:
dbcol = PandasColumn(column_name=str(col), type=df.dtypes[col].name)
dbcol.groupby = dbcol.is_string
dbcol.filterable = dbcol.is_string
dtype = df.dtypes[col].name
# Pandas defaults columns where all values are None to a dtype of
# float with all values as NaN. If we can't correctly infer the
# dtype, we are better off assuming object so we don't create
# large numbers of unwanted Metrics
if self.df[col].isnull().all():
dtype = 'object'
dbcol = PandasColumn(column_name=str(col), type=dtype)
# Only treat `object` as string if we have some data
if self.df[col].notnull().any():
dbcol.groupby = dbcol.is_string
dbcol.filterable = dbcol.is_string
dbcol.sum = dbcol.is_num
dbcol.avg = dbcol.is_num
dbcol.min = dbcol.is_num or dbcol.is_dttm
@@ -786,70 +826,86 @@
if not any_date_col and dbcol.is_time:
any_date_col = col

if dbcol.sum:
metrics.append(PandasMetric(
metric_name='sum__' + dbcol.column_name,
verbose_name='sum__' + dbcol.column_name,
metric_type='sum',
source=dbcol.column_name,
expression='sum'
))
if dbcol.avg:
metrics.append(PandasMetric(
metric_name='avg__' + dbcol.column_name,
verbose_name='avg__' + dbcol.column_name,
metric_type='avg',
source=dbcol.column_name,
expression='mean'
))
if dbcol.max:
metrics.append(PandasMetric(
metric_name='max__' + dbcol.column_name,
verbose_name='max__' + dbcol.column_name,
metric_type='max',
source=dbcol.column_name,
expression='max'
))
if dbcol.min:
metrics.append(PandasMetric(
metric_name='min__' + dbcol.column_name,
verbose_name='min__' + dbcol.column_name,
metric_type='min',
source=dbcol.column_name,
expression='min'
))
if dbcol.count_distinct:
metrics.append(PandasMetric(
metric_name='count_distinct__' + dbcol.column_name,
verbose_name='count_distinct__' + dbcol.column_name,
metric_type='count_distinct',
source=dbcol.column_name,
expression='nunique'
))
dbcol.type = df.dtypes[col].name

metrics.append(PandasMetric(
metric_name='count',
verbose_name='count',
metric_type='count',
source=None,
expression="count"
))
dbmetrics = (
db.session.query(PandasMetric)
.filter(PandasMetric.datasource == self)
.filter(or_(PandasMetric.metric_name == metric.metric_name
for metric in metrics)))
dbmetrics = {metric.metric_name: metric for metric in dbmetrics}
for metric in metrics:
metric.pandas_datasource_id = self.id
if not dbmetrics.get(metric.metric_name, None):
db.session.add(metric)
if not self.main_dttm_col:
self.main_dttm_col = any_date_col

db.session.merge(self)
db.session.commit()
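To see the inference problem the new dtype guard in get_metadata() works around, a small demonstration (column names invented): a CSV column with no values parses as float64, all NaN, and would otherwise look numeric enough to earn sum/avg/min/max metrics:

from io import StringIO

import pandas as pd

df = pd.read_csv(StringIO('name,score\nalice,\nbob,\n'))
print(df.dtypes['score'])           # float64, despite holding no data
print(df['score'].isnull().all())   # True -> metadata treats it as 'object'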


def reconcile_column_metrics(mapper, connection, target):
"""
Create or delete PandasMetrics to match the metric attributes
specified on a PandasColumn
"""
metrics_table = PandasMetric.__table__
for metric_type in ('sum', 'avg', 'max', 'min', 'count_distinct'):
# Set up the metric attributes
metric_name = metric_type + '__' + target.column_name
verbose_name = metric_name
source = target.column_name
if metric_type == 'avg':
expression = 'mean'
elif metric_type == 'count_distinct':
expression = 'nunique'
else:
expression = metric_type

if getattr(target, metric_type):
# Create the metric if it doesn't already exist
result = connection.execute(
metrics_table
.select()
.where(
and_(
metrics_table.c.pandas_datasource_id == target.pandas_datasource_id,
metrics_table.c.metric_name == metric_name)))
if not result.rowcount:
connection.execute(
metrics_table.insert(),
pandas_datasource_id=target.pandas_datasource_id,
metric_name=metric_name,
verbose_name=verbose_name,
source=source,
expression=expression)
else:
# Delete the metric if it exists and hasn't been customized
connection.execute(
metrics_table
.delete()
.where(
and_(
metrics_table.c.pandas_datasource_id == target.pandas_datasource_id,
metrics_table.c.metric_name == metric_name,
metrics_table.c.verbose_name == verbose_name,
metrics_table.c.source == source,
metrics_table.c.expression == expression)))


def reconcile_metric_column(mapper, connection, target):
"""
Clear the metric attribute on a PandasColumn if the
corresponding PandasMetric is deleted
"""
column_table = PandasColumn.__table__
try:
metric_type, column_name = target.metric_name.split('__', 1)
if metric_type in column_table.c:
connection.execute(
column_table
.update()
.values(**{metric_type: False})
.where(
and_(
column_table.c.pandas_datasource_id == target.pandas_datasource_id,
column_table.c.column_name == column_name)))
except ValueError:
# Metric name doesn't contain __
pass


sa.event.listen(PandasColumn, 'after_insert', reconcile_column_metrics)
sa.event.listen(PandasColumn, 'after_update', reconcile_column_metrics)
sa.event.listen(PandasMetric, 'before_delete', reconcile_metric_column)
sa.event.listen(PandasDatasource, 'after_insert', set_perm)
sa.event.listen(PandasDatasource, 'after_update', set_perm)
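With these listeners registered, column flags and generated metrics stay in sync automatically; a hypothetical session flow (datasource and names invented):

# Adding a column with aggregation flags set auto-creates the metrics
col = PandasColumn(column_name='amount', type='float64',
                   pandas_datasource_id=datasource.id,
                   sum=True, avg=True)
db.session.add(col)
db.session.commit()
# -> 'sum__amount' (expression 'sum') and 'avg__amount' (expression
#    'mean') PandasMetrics now exist

# Deleting a generated, uncustomized metric clears the matching flag
metric = (db.session.query(PandasMetric)
          .filter_by(pandas_datasource_id=datasource.id,
                     metric_name='sum__amount')
          .one())
db.session.delete(metric)
db.session.commit()   # before_delete listener resets col.sum to False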
