Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add baseline table to output file generated by the tc --sqldb command #2714

Merged
merged 2 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion taxcalc/cli/tc.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ def cli_tc_main():
default=None)
parser.add_argument('--sqldb',
help=('optional flag that writes SQLite database '
'with dump table containing same output as '
'with two tables (baseline and reform) each '
'containing same output variables as '
'produced by --dump option.'),
default=False,
action="store_true")
Expand Down
43 changes: 33 additions & 10 deletions taxcalc/taxcalcio.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,8 +427,9 @@ def analyze(self, writing_output_file=False,
calculated variables using their Tax-Calculator names

output_sqldb: boolean
whether or not to write SQLite3 database with dump table
containing same output as written by output_dump to a csv file
whether or not to write SQLite3 database with two tables
(baseline and reform) each containing same output as written
by output_dump to a csv file

Returns
-------
Expand All @@ -449,18 +450,28 @@ def analyze(self, writing_output_file=False,
(mtr_paytax, mtr_inctax,
_) = self.calc.mtr(wrt_full_compensation=False,
calc_all_already_called=True)
self.calc_base.calc_all()
calc_base_calculated = True
(mtr_paytax_base, mtr_inctax_base,
_) = self.calc_base.mtr(wrt_full_compensation=False,
calc_all_already_called=True)
else:
# definitely do not need marginal tax rates
mtr_paytax = None
mtr_inctax = None
mtr_paytax_base = None
mtr_inctax_base = None
# extract output if writing_output_file
if writing_output_file:
self.write_output_file(output_dump, dump_varset,
mtr_paytax, mtr_inctax)
self.write_doc_file()
# optionally write --sqldb output to SQLite3 database
if output_sqldb:
self.write_sqldb_file(dump_varset, mtr_paytax, mtr_inctax)
self.write_sqldb_file(
dump_varset, mtr_paytax, mtr_inctax,
mtr_paytax_base, mtr_inctax_base
)
# optionally write --tables output to text file
if output_tables:
if not calc_base_calculated:
Expand All @@ -480,7 +491,9 @@ def write_output_file(self, output_dump, dump_varset,
Write output to CSV-formatted file.
"""
if output_dump:
outdf = self.dump_output(dump_varset, mtr_inctax, mtr_paytax)
outdf = self.dump_output(
self.calc, dump_varset, mtr_inctax, mtr_paytax
)
column_order = sorted(outdf.columns)
else:
outdf = self.minimal_output()
Expand All @@ -504,15 +517,25 @@ def write_doc_file(self):
with open(doc_fname, 'w') as dfile:
dfile.write(doc)

def write_sqldb_file(self, dump_varset, mtr_paytax, mtr_inctax):
def write_sqldb_file(self, dump_varset, mtr_paytax, mtr_inctax,
mtr_paytax_base, mtr_inctax_base):
"""
Write dump output to SQLite3 database tables baseline and reform.
"""
outdf = self.dump_output(dump_varset, mtr_inctax, mtr_paytax)
assert len(outdf.index) == self.calc.array_len
db_fname = self._output_filename.replace('.csv', '.db')
dbcon = sqlite3.connect(db_fname)
outdf.to_sql('dump', dbcon, if_exists='replace', index=False)
# write baseline table
outdf = self.dump_output(
self.calc_base, dump_varset, mtr_inctax_base, mtr_paytax_base
)
assert len(outdf.index) == self.calc.array_len
outdf.to_sql('baseline', dbcon, if_exists='replace', index=False)
# write reform table
outdf = self.dump_output(
self.calc, dump_varset, mtr_inctax, mtr_paytax
)
assert len(outdf.index) == self.calc.array_len
outdf.to_sql('reform', dbcon, if_exists='replace', index=False)
dbcon.close()
del outdf
gc.collect()
Expand Down Expand Up @@ -687,7 +710,7 @@ def minimal_output(self):
odf = pd.DataFrame(data=odict, columns=varlist)
return odf

def dump_output(self, dump_varset, mtr_inctax, mtr_paytax):
def dump_output(self, calcx, dump_varset, mtr_inctax, mtr_paytax):
"""
Extract dump output and return it as Pandas DataFrame.
"""
Expand All @@ -699,7 +722,7 @@ def dump_output(self, dump_varset, mtr_inctax, mtr_paytax):
# create and return dump output DataFrame
odf = pd.DataFrame()
for varname in varset:
vardata = self.calc.array(varname)
vardata = calcx.array(varname)
if varname in recs_vinfo.INTEGER_VARS:
odf[varname] = vardata
else:
Expand Down
Loading