diff --git a/caravel/assets/javascripts/modules/caravel.js b/caravel/assets/javascripts/modules/caravel.js index 88c11e7967fdc..84d5b5f135796 100644 --- a/caravel/assets/javascripts/modules/caravel.js +++ b/caravel/assets/javascripts/modules/caravel.js @@ -19,6 +19,7 @@ var sourceMap = { markup: 'markup.js', para: 'parallel_coordinates.js', pie: 'nvd3_vis.js', + box_plot: 'nvd3_vis.js', pivot_table: 'pivot_table.js', sankey: 'sankey.js', sunburst: 'sunburst.js', @@ -45,6 +46,7 @@ var color = function () { // Color factory var seen = {}; return function (s) { + if (!s) { return; } // next line is for caravel series that should have the same color s = s.replace('---', ''); if (seen[s] === undefined) { diff --git a/caravel/assets/visualizations/nvd3_vis.js b/caravel/assets/visualizations/nvd3_vis.js index db4099456b05c..1df5bebc69b74 100644 --- a/caravel/assets/visualizations/nvd3_vis.js +++ b/caravel/assets/visualizations/nvd3_vis.js @@ -120,6 +120,14 @@ function nvd3Vis(slice) { .staggerLabels(true); break; + case 'box_plot': + colorKey = 'label'; + chart = nv.models.boxPlotChart(); + chart.x(function (d) { return d.label; }); + chart.staggerLabels(true); + chart.maxBoxWidth(75); // prevent boxes from being incredibly wide + break; + default: throw new Error("Unrecognized visualization for nvd3" + viz_type); } diff --git a/caravel/data/__init__.py b/caravel/data/__init__.py index 03969f30084cb..a047e1be3054f 100644 --- a/caravel/data/__init__.py +++ b/caravel/data/__init__.py @@ -263,6 +263,18 @@ def load_world_bank_health_n_pop(): until="now", viz_type='area', groupby=["region"],)), + Slice( + slice_name="Box plot", + viz_type='box_plot', + datasource_type='table', + table=tbl, + params=get_slice_json( + defaults, + since="1960-01-01", + until="now", + whisker_options="Tukey", + viz_type='box_plot', + groupby=["region"],)), ] for slc in slices: merge_slice(slc) diff --git a/caravel/forms.py b/caravel/forms.py index 062427cdb1f42..37893e3711eff 100644 --- a/caravel/forms.py +++ b/caravel/forms.py @@ -315,6 +315,17 @@ def __init__(self, viz): '100', ]) ), + 'whisker_options': FreeFormSelectField( + 'Whisker/outlier options', default="Tukey", + description=( + "Determines how whiskers and outliers are calculated."), + choices=self.choicify([ + 'Tukey', + 'Min/max (no outliers)', + '2/98 percentiles', + '9/91 percentiles', + ]) + ), 'row_limit': FreeFormSelectField( 'Row limit', diff --git a/caravel/viz.py b/caravel/viz.py index 9592235e36b4f..ed039cf9ce922 100644 --- a/caravel/viz.py +++ b/caravel/viz.py @@ -16,6 +16,7 @@ from datetime import datetime, timedelta import pandas as pd +import numpy as np from flask import flash, request, Markup from markdown import markdown from pandas.io.json import dumps @@ -488,6 +489,113 @@ class NVD3Viz(BaseViz): is_timeseries = False +class BoxPlotViz(NVD3Viz): + + """Box plot viz from ND3""" + + viz_type = "box_plot" + verbose_name = "Box Plot" + sort_series = False + is_timeseries = True + fieldsets = ({ + 'label': None, + 'fields': ( + 'metrics', + 'groupby', 'limit', + ), + }, { + 'label': 'Chart Options', + 'fields': ( + 'whisker_options', + ) + },) + + def get_df(self, query_obj=None): + form_data = self.form_data + df = super(BoxPlotViz, self).get_df(query_obj) + + df = df.fillna(0) + + # conform to NVD3 names + def Q1(series): # need to be named functions - can't use lambdas + return np.percentile(series, 25) + + def Q3(series): + return np.percentile(series, 75) + + whisker_type = form_data.get('whisker_options') + if whisker_type == "Tukey": + + def whisker_high(series): + upper_outer_lim = Q3(series) + 1.5 * (Q3(series) - Q1(series)) + series = series[series <= upper_outer_lim] + return series[np.abs(series - upper_outer_lim).argmin()] + + def whisker_low(series): + lower_outer_lim = Q1(series) - 1.5 * (Q3(series) - Q1(series)) + # find the closest value above the lower outer limit + series = series[series >= lower_outer_lim] + return series[np.abs(series - lower_outer_lim).argmin()] + + elif whisker_type == "Min/max (no outliers)": + + def whisker_high(series): + return series.max() + + def whisker_low(series): + return series.min() + + elif " percentiles" in whisker_type: + low, high = whisker_type.replace(" percentiles", "").split("/") + + def whisker_high(series): + return np.percentile(series, int(high)) + + def whisker_low(series): + return np.percentile(series, int(low)) + + else: + raise ValueError("Unknown whisker type: {}".format(whisker_type)) + + def outliers(series): + above = series[series > whisker_high(series)] + below = series[series < whisker_low(series)] + # pandas sometimes doesn't like getting lists back here + return set(above.tolist() + below.tolist()) + + aggregate = [Q1, np.median, Q3, whisker_high, whisker_low, outliers] + df = df.groupby(form_data.get('groupby')).agg(aggregate) + return df + + def to_series(self, df, classed='', title_suffix=''): + label_sep = " - " + chart_data = [] + for index_value, row in zip(df.index, df.to_dict(orient="records")): + if isinstance(index_value, tuple): + index_value = label_sep.join(index_value) + boxes = defaultdict(dict) + for (label, key), value in row.items(): + if key == "median": + key = "Q2" + boxes[label][key] = value + for label, box in boxes.items(): + if len(self.form_data.get("metrics")) > 1: + # need to render data labels with metrics + chart_label = label_sep.join([index_value, label]) + else: + chart_label = index_value + chart_data.append({ + "label": chart_label, + "values": box, + }) + return chart_data + + def get_data(self): + df = self.get_df() + chart_data = self.to_series(df) + return chart_data + + class BubbleViz(NVD3Viz): """Based on the NVD3 bubble chart""" @@ -1387,6 +1495,7 @@ def get_data(self): IFrameViz, ParallelCoordinatesViz, HeatmapViz, + BoxPlotViz, ] viz_types = OrderedDict([(v.viz_type, v) for v in viz_types_list])