Skip to content

Commit

Permalink
Merge pull request #857 from jdebacker/plots
Browse files Browse the repository at this point in the history
Merging
  • Loading branch information
rickecon authored Feb 16, 2023
2 parents 4fb68cb + be52d62 commit 21c2a5c
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 45 deletions.
2 changes: 1 addition & 1 deletion ogcore/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"B": "Wealth ($B_t$)",
"I_total": "Investment ($I_t$)",
"K": "Capital Stock ($K_t$)",
"Y_vec": "GDP ($Y_t$)",
"Y_vec": "Output ($Y_t$)",
"C_vec": "Consumption ($C_t$)",
"L_vec": "Labor ($L_t$)",
"K_vec": "Capital Stock ($K_t$)",
Expand Down
51 changes: 26 additions & 25 deletions ogcore/output_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def plot_aggregates(
plt.legend(loc=9, bbox_to_anchor=(0.5, -0.15), ncol=2)
if path:
fig_path1 = os.path.join(path)
plt.savefig(fig_path1, bbox_inches="tight")
plt.savefig(fig_path1, bbox_inches="tight", dpi=300)
else:
return fig1
plt.close()
Expand All @@ -163,6 +163,7 @@ def plot_industry_aggregates(
reform_tpi=None,
reform_params=None,
var_list=["Y_vec"],
ind_names_list=None,
plot_type="pct_diff",
num_years_to_plot=50,
start_year=DEFAULT_START_YEAR,
Expand Down Expand Up @@ -207,6 +208,11 @@ def plot_industry_aggregates(
"""
assert isinstance(start_year, (int, np.integer))
assert isinstance(num_years_to_plot, int)
dims = base_tpi[var_list[0]].shape[1]
if ind_names_list:
assert len(ind_names_list) == dims
else:
ind_names_list = [str(i) for i in range(dims)]
# Make sure both runs cover same time period
if reform_tpi:
assert base_params.start_year == reform_params.start_year
Expand All @@ -217,7 +223,11 @@ def plot_industry_aggregates(
assert reform_tpi is not None
fig1, ax1 = plt.subplots()
for i, v in enumerate(var_list):
for m in range(base_params.M):
if len(var_list) == 1:
var_label = ""
else:
var_label = VAR_LABELS[v]
for m in range(dims):
if plot_type == "pct_diff":
plot_var = (
reform_tpi[v][:, m] - base_tpi[v][:, m]
Expand All @@ -226,37 +236,31 @@ def plot_industry_aggregates(
plt.plot(
year_vec,
plot_var[start_index : start_index + num_years_to_plot],
label=VAR_LABELS[v] + "for industry " + str(m),
label=var_label + " " + ind_names_list[m],
)
elif plot_type == "diff":
plot_var = reform_tpi[v][:, m] - base_tpi[v][:, m]
ylabel = r"Difference (Model Units)"
plt.plot(
year_vec,
plot_var[start_index : start_index + num_years_to_plot],
label=VAR_LABELS[v] + "for industry " + str(m),
label=var_label + " " + ind_names_list[m],
)
elif plot_type == "levels":
plt.plot(
year_vec,
base_tpi[v][
start_index : start_index + num_years_to_plot, m
],
label="Baseline "
+ VAR_LABELS[v]
+ "for industry "
+ str(m),
label="Baseline " + var_label + " " + ind_names_list[m],
)
if reform_tpi:
plt.plot(
year_vec,
reform_tpi[v][
start_index : start_index + num_years_to_plot, m
],
label="Reform "
+ VAR_LABELS[v]
+ "for industry "
+ str(m),
label="Reform " + var_label + " " + ind_names_list[m],
)
ylabel = r"Model Units"
elif plot_type == "forecast":
Expand All @@ -267,10 +271,7 @@ def plot_industry_aggregates(
plt.plot(
year_vec,
plot_var_base,
label="Baseline "
+ VAR_LABELS[v]
+ "for industry "
+ str(m),
label="Baseline " + var_label + " " + ind_names_list[m],
)
# Plot change from baseline forecast
pct_change = (
Expand All @@ -287,7 +288,7 @@ def plot_industry_aggregates(
plt.plot(
year_vec,
plot_var_reform,
label="Reform " + VAR_LABELS[v] + "for industry " + str(m),
label="Reform " + var_label + " " + ind_names_list[m],
)
# making units labels will not work if multiple variables
# and they are in different units
Expand Down Expand Up @@ -316,7 +317,7 @@ def plot_industry_aggregates(
plt.legend(loc=9, bbox_to_anchor=(0.5, -0.15), ncol=2)
if path:
fig_path1 = os.path.join(path)
plt.savefig(fig_path1, bbox_inches="tight")
plt.savefig(fig_path1, bbox_inches="tight", dpi=300)
else:
return fig1
plt.close()
Expand Down Expand Up @@ -380,7 +381,7 @@ def ss_3Dplot(
if plot_title:
plt.title(plot_title)
if path:
plt.savefig(path)
plt.savefig(path, dpi=300)
else:
return plt

Expand Down Expand Up @@ -502,7 +503,7 @@ def plot_gdp_ratio(
plt.legend(loc=9, bbox_to_anchor=(0.5, -0.15), ncol=2)
if path:
fig_path1 = os.path.join(path)
plt.savefig(fig_path1, bbox_inches="tight")
plt.savefig(fig_path1, bbox_inches="tight", dpi=300)
else:
return fig1
plt.close()
Expand Down Expand Up @@ -576,7 +577,7 @@ def ability_bar(
plt.title(plot_title, fontsize=15)
if path:
fig_path1 = os.path.join(path)
plt.savefig(fig_path1, bbox_inches="tight")
plt.savefig(fig_path1, bbox_inches="tight", dpi=300)
else:
return fig
plt.close()
Expand Down Expand Up @@ -630,7 +631,7 @@ def ability_bar_ss(
plt.legend(loc=9, bbox_to_anchor=(0.5, -0.15), ncol=2)
if path:
fig_path1 = os.path.join(path)
plt.savefig(fig_path1, bbox_inches="tight")
plt.savefig(fig_path1, bbox_inches="tight", dpi=300)
else:
return fig
plt.close()
Expand Down Expand Up @@ -717,7 +718,7 @@ def tpi_profiles(
plt.title(plot_title, fontsize=15)
if path:
fig_path1 = os.path.join(path)
plt.savefig(fig_path1, bbox_inches="tight")
plt.savefig(fig_path1, bbox_inches="tight", dpi=300)
else:
return fig1
plt.close()
Expand Down Expand Up @@ -797,7 +798,7 @@ def ss_profiles(
plt.title(plot_title, fontsize=15)
if path:
fig_path1 = os.path.join(path)
plt.savefig(fig_path1, bbox_inches="tight")
plt.savefig(fig_path1, bbox_inches="tight", dpi=300)
else:
return fig1
plt.close()
Expand Down Expand Up @@ -1166,7 +1167,7 @@ def inequality_plot(
plt.legend(loc=9, bbox_to_anchor=(0.5, -0.15), ncol=2)
if path:
fig_path1 = os.path.join(path)
plt.savefig(fig_path1, bbox_inches="tight")
plt.savefig(fig_path1, bbox_inches="tight", dpi=300)
else:
return fig1
plt.close()
38 changes: 19 additions & 19 deletions ogcore/parameter_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def plot_imm_rates(p, year=DEFAULT_START_YEAR, include_title=False, path=None):
return fig
else:
fig_path = os.path.join(path, "imm_rates_orig")
plt.savefig(fig_path)
plt.savefig(fig_path, dpi=300)


def plot_mort_rates(p, include_title=False, path=None):
Expand Down Expand Up @@ -66,7 +66,7 @@ def plot_mort_rates(p, include_title=False, path=None):
return fig
else:
fig_path = os.path.join(path, "mortality_rates")
plt.savefig(fig_path)
plt.savefig(fig_path, dpi=300)


def plot_pop_growth(
Expand Down Expand Up @@ -106,7 +106,7 @@ def plot_pop_growth(
return fig
else:
fig_path = os.path.join(path, "pop_growth_rates")
plt.savefig(fig_path)
plt.savefig(fig_path, dpi=300)


def plot_population(p, years_to_plot=["SS"], include_title=False, path=None):
Expand Down Expand Up @@ -145,7 +145,7 @@ def plot_population(p, years_to_plot=["SS"], include_title=False, path=None):
return fig
else:
fig_path = os.path.join(path, "pop_distribution")
plt.savefig(fig_path)
plt.savefig(fig_path, dpi=300)


def plot_ability_profiles(p, include_title=False, path=None):
Expand Down Expand Up @@ -233,7 +233,7 @@ def plot_elliptical_u(p, plot_MU=True, include_title=False, path=None):
return fig
else:
fig_path = os.path.join(path, "ellipse_v_CFE")
plt.savefig(fig_path)
plt.savefig(fig_path, dpi=300)


def plot_chi_n(p, include_title=False, path=None):
Expand All @@ -259,7 +259,7 @@ def plot_chi_n(p, include_title=False, path=None):
return fig
else:
fig_path = os.path.join(path, "chi_n_values")
plt.savefig(fig_path)
plt.savefig(fig_path, dpi=300)


def plot_fert_rates(
Expand Down Expand Up @@ -332,7 +332,7 @@ def plot_fert_rates(
# Save or return figure
if output_dir:
output_path = os.path.join(output_dir, "fert_rates")
plt.savefig(output_path)
plt.savefig(output_path, dpi=300)
plt.close()
else:
return fig
Expand Down Expand Up @@ -394,7 +394,7 @@ def plot_mort_rates_data(
np.hstack([infmort_rate, mort_rates_all[min_yr - 1 : max_yr]]),
)
plt.axvline(x=max_yr, color="red", linestyle="-", linewidth=1)
plt.grid(b=True, which="major", color="0.65", linestyle="-")
plt.grid(visible=True, which="major", color="0.65", linestyle="-")
# plt.title('Fitted mortality rate function by age ($rho_{s}$)',
# fontsize=20)
plt.xlabel(r"Age $s$")
Expand All @@ -411,7 +411,7 @@ def plot_mort_rates_data(
# Save or return figure
if output_dir:
output_path = os.path.join(output_dir, "mort_rates")
plt.savefig(output_path)
plt.savefig(output_path, dpi=300)
plt.close()
else:
return fig
Expand Down Expand Up @@ -455,7 +455,7 @@ def plot_omega_fixed(
# Save or return figure
if output_dir:
output_path = os.path.join(output_dir, "OrigVsFixSSpop")
plt.savefig(output_path)
plt.savefig(output_path, dpi=300)
plt.close()
else:
return fig
Expand Down Expand Up @@ -496,7 +496,7 @@ def plot_imm_fixed(
# Save or return figure
if output_dir:
output_path = os.path.join(output_dir, "OrigVsAdjImm")
plt.savefig(output_path)
plt.savefig(output_path, dpi=300)
plt.close()
else:
return fig
Expand Down Expand Up @@ -562,7 +562,7 @@ def plot_population_path(
# Save or return figure
if output_dir:
output_path = os.path.join(output_dir, "PopDistPath")
plt.savefig(output_path)
plt.savefig(output_path, dpi=300)
plt.close()
else:
return fig
Expand Down Expand Up @@ -853,7 +853,7 @@ def txfunc_sse_plot(age_vec, sse_mat, start_year, varstr, output_dir, round):
plt.ylabel(r"SSE")
graphname = "SSE_" + varstr + "_Round" + str(round)
output_path = os.path.join(output_dir, graphname)
plt.savefig(output_path, bbox_inches="tight")
plt.savefig(output_path, bbox_inches="tight", dpi=300)
plt.close()


Expand Down Expand Up @@ -889,15 +889,15 @@ def plot_income_data(
plt.plot(ages, emat)
filename = "ability_2D_lev" + filesuffix
fullpath = os.path.join(output_dir, filename)
plt.savefig(fullpath)
plt.savefig(fullpath, dpi=300)
plt.close()

# Plot of 2D, J=1 in logs
plt.figure()
plt.plot(ages, np.log(emat))
filename = "ability_2D_log" + filesuffix
fullpath = os.path.join(output_dir, filename)
plt.savefig(fullpath)
plt.savefig(fullpath, dpi=300)
plt.close()
else:
# Plot of 3D, J>1 in levels
Expand All @@ -910,7 +910,7 @@ def plot_income_data(
ax10.set_zlabel(r"ability $e_{j,s}$")
filename = "ability_3D_lev" + filesuffix
fullpath = os.path.join(output_dir, filename)
plt.savefig(fullpath)
plt.savefig(fullpath, dpi=300)
plt.close()

# Plot of 3D, J>1 in logs
Expand All @@ -928,7 +928,7 @@ def plot_income_data(
ax11.set_zlabel(r"log ability $log(e_{j,s})$")
filename = "ability_3D_log" + filesuffix
fullpath = os.path.join(output_dir, filename)
plt.savefig(fullpath)
plt.savefig(fullpath, dpi=300)
plt.close()

if J <= 10: # Restricted because of line and marker types
Expand Down Expand Up @@ -976,7 +976,7 @@ def plot_income_data(
ax.set_ylabel(r"log ability $log(e_{j,s})$")
filename = "ability_2D_log" + filesuffix
fullpath = os.path.join(output_dir, filename)
plt.savefig(fullpath)
plt.savefig(fullpath, dpi=300)
plt.close()
else:
if J <= 10: # Restricted because of line and marker types
Expand Down Expand Up @@ -1189,4 +1189,4 @@ def wm(x):
if path is None:
return fig
else:
plt.savefig(path)
plt.savefig(path, dpi=300)

0 comments on commit 21c2a5c

Please sign in to comment.