Skip to content

Commit

Permalink
Merge pull request #1606 from martinholmer/decile-change-graph
Browse files Browse the repository at this point in the history
Add utility functions that generate a change-in-aftertax-income-by-decile graph
  • Loading branch information
martinholmer authored Nov 4, 2017
2 parents 4eec5ac + e2d724c commit a3614a1
Show file tree
Hide file tree
Showing 2 changed files with 187 additions and 1 deletion.
28 changes: 28 additions & 0 deletions taxcalc/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
add_income_bins, add_quantile_bins,
multiyear_diagnostic_table,
mtr_graph_data, atr_graph_data,
dec_graph_data, dec_graph_plot,
xtr_graph_plot, write_graph_file,
read_egg_csv, read_egg_json, delete_file,
bootstrap_se_ci,
Expand Down Expand Up @@ -950,3 +951,30 @@ def test_table_columns_labels():
# check that length of two lists are the same
assert len(DIST_TABLE_COLUMNS) == len(DIST_TABLE_LABELS)
assert len(DIFF_TABLE_COLUMNS) == len(DIFF_TABLE_LABELS)


def test_dec_graph_plot(cps_subsample):
pol = Policy()
rec = Records.cps_constructor(data=cps_subsample)
calc1 = Calculator(policy=pol, records=rec)
year = 2020
reform = {
year: {
'_SS_Earnings_c': [9e99], # OASDI FICA tax on all earnings
'_FICA_ss_trt': [0.107484] # lower rate to keep revenue unchanged

}
}
pol.implement_reform(reform)
calc2 = Calculator(policy=pol, records=rec)
calc1.advance_to_year(year)
with pytest.raises(ValueError):
dec_graph_data(calc1, calc2)
calc2.advance_to_year(year)
gdata = dec_graph_data(calc1, calc2)
assert isinstance(gdata, dict)
deciles = gdata['bars'].keys()
assert len(deciles) == 14
gplot = dec_graph_plot(gdata, xlabel='', ylabel='')
assert gplot
# write_graph_file(gplot, 'test.html', 'Test Plot')
160 changes: 159 additions & 1 deletion taxcalc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pandas as pd
import bokeh.io as bio
import bokeh.plotting as bp
from bokeh.models import PrintfTickFormatter
from taxcalc.utilsprvt import (weighted_count_lt_zero,
weighted_count_gt_zero,
weighted_count, weighted_mean,
Expand Down Expand Up @@ -107,7 +108,7 @@
DECILE_ROW_NAMES = ['0-10', '10-20', '20-30', '30-40', '40-50',
'50-60', '60-70', '70-80', '80-90', '90-100',
'all',
'90-95', '95-99', '99-100']
'90-95', '95-99', 'Top 1%']

WEBAPP_INCOME_BINS = [-9e99, 0, 9999, 19999, 29999, 39999, 49999, 74999, 99999,
199999, 499999, 1000000, 9e99]
Expand Down Expand Up @@ -1382,3 +1383,160 @@ def bootstrap_se_ci(data, seed, num_samples, statistic, alpha):
bsest['cilo'] = stat[int(round(alpha * num_samples)) - 1]
bsest['cihi'] = stat[int(round((1 - alpha) * num_samples)) - 1]
return bsest


def dec_graph_data(calc1, calc2):
"""
Prepare data needed by dec_graph_plot utility function.
Parameters
----------
calc1 : a Calculator object that refers to baseline policy
calc2 : a Calculator object that refers to reform policy
Returns
-------
dictionary object suitable for passing to dec_graph_plot utility function
"""
# check that two calculator objects have the same current_year
if calc1.current_year == calc2.current_year:
year = calc1.current_year
else:
msg = 'calc1.current_year={} != calc2.current_year={}'
raise ValueError(msg.format(calc1.current_year, calc2.current_year))
# create difference table from the two Calculator objects
calc1.calc_all()
calc2.calc_all()
diff_table = create_difference_table(calc1.records, calc2.records,
groupby='weighted_deciles',
income_measure='expanded_income',
tax_to_diff='combined')
# construct dictionary containing the bar data required by dec_graph_plot
bars = dict()
for idx in range(0, 14): # the ten income deciles, all, plus top details
info = dict()
info['label'] = DECILE_ROW_NAMES[idx]
info['value'] = diff_table['pc_aftertaxinc'][idx]
if info['label'] == 'all':
info['label'] = '---------'
info['value'] = 0
bars[idx] = info
# construct dictionary containing bar data and auto-generated labels
data = dict()
data['bars'] = bars
xlabel = 'Reform-Induced Percentage Change in After-Tax Expanded Income'
data['xlabel'] = xlabel
ylabel = 'Expanded Income Percentile Group'
data['ylabel'] = ylabel
title_str = 'Change in After-Tax Income by Income Percentile Group'
data['title'] = '{} for {}'.format(title_str, year)
return data


def dec_graph_plot(data,
width=850,
height=500,
xlabel='',
ylabel='',
title=''):
"""
Plot stacked decile graph using data returned from dec_graph_data function.
Parameters
----------
data : dictionary object returned from dec_graph_data() utility function
width : integer
width of plot expressed in pixels
height : integer
height of plot expressed in pixels
xlabel : string
x-axis label; if '', then use label generated by dec_graph_data
ylabel : string
y-axis label; if '', then use label generated by dec_graph_data
title : string
graph title; if '', then use title generated by dec_graph_data
Returns
-------
bokeh.plotting figure object containing a raster graphics plot
Notes
-----
USAGE EXAMPLE::
gdata = dec_graph_data(calc1, calc2)
gplot = dec_graph_plot(gdata)
THEN when working interactively in a Python notebook::
bp.show(gplot)
OR when executing script using Python command-line interpreter::
bio.output_file('graph-name.html', title='Change in After-Tax Income')
bio.show(gplot) [OR bio.save(gplot) WILL JUST WRITE FILE TO DISK]
WILL VISUALIZE GRAPH IN BROWSER AND WRITE GRAPH TO SPECIFIED HTML FILE
To convert the visualized graph into a PNG-formatted file, click on
the "Save" icon on the Toolbar (located in the top-right corner of
the visualized graph) and a PNG-formatted file will written to your
Download directory.
The ONLY output option the bokeh.plotting figure has is HTML format,
which (as described above) can be converted into a PNG-formatted
raster graphics file. There is no option to make the bokeh.plotting
figure generate a vector graphics file such as an EPS file.
"""
# pylint: disable=too-many-arguments,too-many-locals
if title == '':
title = data['title']
bar_keys = sorted(data['bars'].keys())
bar_labels = [data['bars'][key]['label'] for key in bar_keys]
fig = bp.figure(plot_width=width, plot_height=height, title=title,
y_range=bar_labels)
fig.title.text_font_size = '12pt'
fig.outline_line_color = None
fig.axis.axis_line_color = None
fig.axis.minor_tick_line_color = None
fig.axis.axis_label_text_font_size = '12pt'
fig.axis.axis_label_text_font_style = 'normal'
fig.axis.major_label_text_font_size = '12pt'
if xlabel == '':
xlabel = data['xlabel']
fig.xaxis.axis_label = xlabel
fig.xaxis[0].formatter = PrintfTickFormatter(format='%+d%%')
if ylabel == '':
ylabel = data['ylabel']
fig.yaxis.axis_label = ylabel
fig.ygrid.grid_line_color = None
# plot thick x-axis grid line at zero
fig.line(x=[0, 0], y=[0, 14], line_width=1, line_color='black')
# plot bars
barheight = 0.8
bcolor = 'blue'
yidx = 0
for idx in bar_keys:
bval = data['bars'][idx]['value']
blabel = data['bars'][idx]['label']
bheight = barheight
if blabel == '90-95':
bheight *= 0.5
bcolor = 'red'
elif blabel == '95-99':
bheight *= 0.4
elif blabel == 'Top 1%':
bheight *= 0.1
fig.rect(x=(bval / 2.0), # x-coordinate of center of the rectangle
y=(yidx + 0.5), # y-coordinate of center of the rectangle
width=abs(bval), # width of the rectangle
height=bheight, # height of the rectangle
color=bcolor)
yidx += 1
return fig

0 comments on commit a3614a1

Please sign in to comment.