Skip to content

Commit

Permalink
Make MetricsControl the standard across visualizations (apache#4914)
Browse files Browse the repository at this point in the history
* [WiP] make MetricsControl the standard across visualizations

This spreads MetricsControl across visualizations.

* Addressing comments

* Fix deepcopy issue using shallow copy

* Fix tests
  • Loading branch information
mistercrunch authored and michellethomas committed May 23, 2018
1 parent 971480f commit a1b6580
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 104 deletions.
117 changes: 38 additions & 79 deletions superset/assets/src/explore/controls.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ import {
import * as v from './validators';
import { colorPrimary, ALL_COLOR_SCHEMES, spectrums } from '../modules/colors';
import { defaultViewport } from '../modules/geo';
import MetricOption from '../components/MetricOption';
import ColumnOption from '../components/ColumnOption';
import OptionDescription from '../components/OptionDescription';
import { t } from '../locales';
Expand Down Expand Up @@ -116,6 +115,32 @@ const groupByControl = {
},
};

const metrics = {
type: 'MetricsControl',
multi: true,
label: t('Metrics'),
validators: [v.nonEmpty],
default: (c) => {
const metric = mainMetric(c.savedMetrics);
return metric ? [metric] : null;
},
mapStateToProps: (state) => {
const datasource = state.datasource;
return {
columns: datasource ? datasource.columns : [],
savedMetrics: datasource ? datasource.metrics : [],
datasourceType: datasource && datasource.type,
};
},
description: t('One or many metrics to display'),
};
const metric = {
...metrics,
multi: false,
label: t('Metric'),
default: props => mainMetric(props.savedMetrics),
};

const sandboxUrl = (
'https://github.com/apache/incubator-superset/' +
'blob/master/superset/assets/src/modules/sandbox.js');
Expand Down Expand Up @@ -152,6 +177,11 @@ function jsFunctionControl(label, description, extraDescr = null, height = 100,
}

export const controls = {

metrics,

metric,

datasource: {
type: 'DatasourceControl',
label: t('Datasource'),
Expand All @@ -169,36 +199,11 @@ export const controls = {
description: t('The type of visualization to display'),
},

metrics: {
type: 'MetricsControl',
multi: true,
label: t('Metrics'),
validators: [v.nonEmpty],
default: (c) => {
const metric = mainMetric(c.savedMetrics);
return metric ? [metric] : null;
},
mapStateToProps: (state) => {
const datasource = state.datasource;
return {
columns: datasource ? datasource.columns : [],
savedMetrics: datasource ? datasource.metrics : [],
datasourceType: datasource && datasource.type,
};
},
description: t('One or many metrics to display'),
},

percent_metrics: {
type: 'SelectControl',
...metrics,
multi: true,
label: t('Percentage Metrics'),
valueKey: 'metric_name',
optionRenderer: m => <MetricOption metric={m} showType />,
valueRenderer: m => <MetricOption metric={m} />,
mapStateToProps: state => ({
options: (state.datasource) ? state.datasource.metrics : [],
}),
validators: [],
description: t('Metrics for which percentage of total are to be displayed'),
},

Expand Down Expand Up @@ -262,33 +267,11 @@ export const controls = {
renderTrigger: true,
},

metric: {
type: 'MetricsControl',
multi: false,
label: t('Metric'),
clearable: false,
validators: [v.nonEmpty],
default: props => mainMetric(props.savedMetrics),
mapStateToProps: state => ({
columns: state.datasource ? state.datasource.columns : [],
savedMetrics: state.datasource ? state.datasource.metrics : [],
datasourceType: state.datasource && state.datasource.type,
}),
},

metric_2: {
type: 'SelectControl',
...metric,
label: t('Right Axis Metric'),
default: null,
validators: [v.nonEmpty],
clearable: true,
description: t('Choose a metric for right axis'),
valueKey: 'metric_name',
optionRenderer: m => <MetricOption metric={m} showType />,
valueRenderer: m => <MetricOption metric={m} />,
mapStateToProps: state => ({
options: (state.datasource) ? state.datasource.metrics : [],
}),
},

stacked_style: {
Expand Down Expand Up @@ -508,13 +491,10 @@ export const controls = {
},

secondary_metric: {
type: 'SelectControl',
...metric,
label: t('Color Metric'),
default: null,
description: t('A metric to use for color'),
mapStateToProps: state => ({
choices: (state.datasource) ? state.datasource.metrics_combo : [],
}),
},
select_country: {
type: 'SelectControl',
Expand Down Expand Up @@ -1105,44 +1085,23 @@ export const controls = {
},

x: {
type: 'SelectControl',
...metric,
label: t('X Axis'),
description: t('Metric assigned to the [X] axis'),
default: null,
validators: [v.nonEmpty],
optionRenderer: m => <MetricOption metric={m} showType />,
valueRenderer: m => <MetricOption metric={m} />,
valueKey: 'metric_name',
mapStateToProps: state => ({
options: (state.datasource) ? state.datasource.metrics : [],
}),
},

y: {
type: 'SelectControl',
...metric,
label: t('Y Axis'),
default: null,
validators: [v.nonEmpty],
description: t('Metric assigned to the [Y] axis'),
optionRenderer: m => <MetricOption metric={m} showType />,
valueRenderer: m => <MetricOption metric={m} />,
valueKey: 'metric_name',
mapStateToProps: state => ({
options: (state.datasource) ? state.datasource.metrics : [],
}),
},

size: {
type: 'SelectControl',
...metric,
label: t('Bubble Size'),
default: null,
validators: [v.nonEmpty],
optionRenderer: m => <MetricOption metric={m} showType />,
valueRenderer: m => <MetricOption metric={m} />,
valueKey: 'metric_name',
mapStateToProps: state => ({
options: (state.datasource) ? state.datasource.metrics : [],
}),
},

url: {
Expand Down
4 changes: 2 additions & 2 deletions superset/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1168,10 +1168,10 @@ def load_multiformat_time_series_data():
obj.fetch_metadata()
tbl = obj

print("Creating some slices")
print("Creating Heatmap charts")
for i, col in enumerate(tbl.columns):
slice_data = {
"metric": 'count',
"metrics": ['count'],
"granularity_sqla": col.column_name,
"granularity_sqla": "day",
"row_limit": config.get("ROW_LIMIT"),
Expand Down
70 changes: 47 additions & 23 deletions superset/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@
config = app.config
stats_logger = config.get('STATS_LOGGER')

METRIC_KEYS = [
'metric', 'metrics', 'percent_metrics', 'metric_2', 'secondary_metric',
'x', 'y', 'size',
]


class BaseViz(object):

Expand All @@ -66,13 +71,6 @@ def __init__(self, datasource, form_data, force=False):
self.query = ''
self.token = self.form_data.get(
'token', 'token_' + uuid.uuid4().hex[:8])
metrics = self.form_data.get('metrics') or []
self.metrics = []
for metric in metrics:
if isinstance(metric, dict):
self.metrics.append(metric['label'])
else:
self.metrics.append(metric)

self.groupby = self.form_data.get('groupby') or []
self.time_shift = timedelta()
Expand All @@ -90,6 +88,29 @@ def __init__(self, datasource, form_data, force=False):
self._any_cached_dttm = None
self._extra_chart_data = None

self.process_metrics()

def process_metrics(self):
self.metric_dict = {}
fd = self.form_data
for mkey in METRIC_KEYS:
val = fd.get(mkey)
if val:
if not isinstance(val, list):
val = [val]
for o in val:
self.metric_dict[self.get_metric_label(o)] = o

# Cast to list needed to return serializable object in py3
self.all_metrics = list(self.metric_dict.values())
self.metric_labels = list(self.metric_dict.keys())

def get_metric_label(self, metric):
if isinstance(metric, string_types):
return metric
if isinstance(metric, dict):
return metric.get('label')

@staticmethod
def handle_js_int_overflow(data):
for d in data.get('records', dict()):
Expand Down Expand Up @@ -202,7 +223,7 @@ def query_obj(self):
"""Building a query object"""
form_data = self.form_data
gb = form_data.get('groupby') or []
metrics = form_data.get('metrics') or []
metrics = self.all_metrics or []
columns = form_data.get('columns') or []
groupby = []
for o in gb + columns:
Expand Down Expand Up @@ -346,7 +367,7 @@ def cache_key(self, query_obj):
and replace them with the use-provided inputs to bounds, which
may we time-relative (as in "5 days ago" or "now").
"""
cache_dict = copy.deepcopy(query_obj)
cache_dict = copy.copy(query_obj)

for k in ['from_dttm', 'to_dttm']:
del cache_dict[k]
Expand Down Expand Up @@ -520,7 +541,7 @@ def query_obj(self):
'Choose either fields to [Group By] and [Metrics] or '
'[Columns], not both'))

sort_by = fd.get('timeseries_limit_metric')
sort_by = fd.get('timeseries_limit_metric') or []
if fd.get('all_columns'):
d['columns'] = fd.get('all_columns')
d['groupby'] = []
Expand All @@ -535,7 +556,7 @@ def query_obj(self):
if 'percent_metrics' in fd:
d['metrics'] = d['metrics'] + list(filter(
lambda m: m not in d['metrics'],
fd['percent_metrics'],
fd['percent_metrics'] or [],
))

d['is_timeseries'] = self.should_be_timeseries()
Expand All @@ -551,7 +572,8 @@ def get_data(self, df):
del df[DTTM_ALIAS]

# Sum up and compute percentages for all percent metrics
percent_metrics = fd.get('percent_metrics', [])
percent_metrics = fd.get('percent_metrics') or []

if len(percent_metrics):
percent_metrics = list(filter(lambda m: m in df, percent_metrics))
metric_sums = {
Expand Down Expand Up @@ -611,10 +633,10 @@ def query_obj(self):

def get_data(self, df):
fd = self.form_data
values = self.metrics
columns = None
values = self.metric_labels
if fd.get('groupby'):
values = self.metrics[0]
values = self.metric_labels[0]
columns = fd.get('groupby')
pt = df.pivot_table(
index=DTTM_ALIAS,
Expand Down Expand Up @@ -780,7 +802,7 @@ def get_data(self, df):

data = {}
records = df.to_dict('records')
for metric in self.metrics:
for metric in self.metric_labels:
data[metric] = {
str(obj[DTTM_ALIAS].value / 10**9): obj.get(metric)
for obj in records
Expand Down Expand Up @@ -1109,7 +1131,7 @@ def to_series(self, df, classed='', title_suffix=''):
if (
isinstance(series_title, (list, tuple)) and
len(series_title) > 1 and
len(self.metrics) == 1):
len(self.metric_labels) == 1):
# Removing metric from series name if only one metric
series_title = series_title[1:]
if title_suffix:
Expand Down Expand Up @@ -1393,10 +1415,11 @@ class DistributionPieViz(NVD3Viz):
is_timeseries = False

def get_data(self, df):
metric = self.metric_labels[0]
df = df.pivot_table(
index=self.groupby,
values=[self.metrics[0]])
df.sort_values(by=self.metrics[0], ascending=False, inplace=True)
values=[metric])
df.sort_values(by=metric, ascending=False, inplace=True)
df = df.reset_index()
df.columns = ['x', 'y']
return df.to_dict(orient='records')
Expand Down Expand Up @@ -1468,14 +1491,15 @@ def query_obj(self):

def get_data(self, df):
fd = self.form_data
metrics = self.metric_labels

row = df.groupby(self.groupby).sum()[self.metrics[0]].copy()
row = df.groupby(self.groupby).sum()[metrics[0]].copy()
row.sort_values(ascending=False, inplace=True)
columns = fd.get('columns') or []
pt = df.pivot_table(
index=self.groupby,
columns=columns,
values=self.metrics)
values=metrics)
if fd.get('contribution'):
pt = pt.fillna(0)
pt = pt.T
Expand All @@ -1487,7 +1511,7 @@ def get_data(self, df):
continue
if isinstance(name, string_types):
series_title = name
elif len(self.metrics) > 1:
elif len(metrics) > 1:
series_title = ', '.join(name)
else:
l = [str(s) for s in name[1:]] # noqa: E741
Expand Down Expand Up @@ -1664,7 +1688,7 @@ def query_obj(self):
def get_data(self, df):
fd = self.form_data
cols = [fd.get('entity')]
metric = fd.get('metric')
metric = self.metric_labels[0]
cols += [metric]
ndf = df[cols]
df = ndf
Expand Down Expand Up @@ -1836,7 +1860,7 @@ def get_data(self, df):
fd = self.form_data
x = fd.get('all_columns_x')
y = fd.get('all_columns_y')
v = fd.get('metric')
v = self.metric_labels[0]
if x == y:
df.columns = ['x', 'y', 'v']
else:
Expand Down

0 comments on commit a1b6580

Please sign in to comment.