diff --git a/ogcore/constants.py b/ogcore/constants.py index 68c8eb79b..533051873 100644 --- a/ogcore/constants.py +++ b/ogcore/constants.py @@ -16,7 +16,7 @@ "B": "Wealth ($B_t$)", "I_total": "Investment ($I_t$)", "K": "Capital Stock ($K_t$)", - "Y_vec": "GDP ($Y_t$)", + "Y_vec": "Output ($Y_t$)", "C_vec": "Consumption ($C_t$)", "L_vec": "Labor ($L_t$)", "K_vec": "Capital Stock ($K_t$)", diff --git a/ogcore/output_plots.py b/ogcore/output_plots.py index c275c5e52..bf7af7fc5 100644 --- a/ogcore/output_plots.py +++ b/ogcore/output_plots.py @@ -151,7 +151,7 @@ def plot_aggregates( plt.legend(loc=9, bbox_to_anchor=(0.5, -0.15), ncol=2) if path: fig_path1 = os.path.join(path) - plt.savefig(fig_path1, bbox_inches="tight") + plt.savefig(fig_path1, bbox_inches="tight", dpi=300) else: return fig1 plt.close() @@ -163,6 +163,7 @@ def plot_industry_aggregates( reform_tpi=None, reform_params=None, var_list=["Y_vec"], + ind_names_list=None, plot_type="pct_diff", num_years_to_plot=50, start_year=DEFAULT_START_YEAR, @@ -207,6 +208,11 @@ def plot_industry_aggregates( """ assert isinstance(start_year, (int, np.integer)) assert isinstance(num_years_to_plot, int) + dims = base_tpi[var_list[0]].shape[1] + if ind_names_list: + assert len(ind_names_list) == dims + else: + ind_names_list = [str(i) for i in range(dims)] # Make sure both runs cover same time period if reform_tpi: assert base_params.start_year == reform_params.start_year @@ -217,7 +223,11 @@ def plot_industry_aggregates( assert reform_tpi is not None fig1, ax1 = plt.subplots() for i, v in enumerate(var_list): - for m in range(base_params.M): + if len(var_list) == 1: + var_label = "" + else: + var_label = VAR_LABELS[v] + for m in range(dims): if plot_type == "pct_diff": plot_var = ( reform_tpi[v][:, m] - base_tpi[v][:, m] @@ -226,7 +236,7 @@ def plot_industry_aggregates( plt.plot( year_vec, plot_var[start_index : start_index + num_years_to_plot], - label=VAR_LABELS[v] + "for industry " + str(m), + label=var_label + " " + ind_names_list[m], ) elif plot_type == "diff": plot_var = reform_tpi[v][:, m] - base_tpi[v][:, m] @@ -234,7 +244,7 @@ def plot_industry_aggregates( plt.plot( year_vec, plot_var[start_index : start_index + num_years_to_plot], - label=VAR_LABELS[v] + "for industry " + str(m), + label=var_label + " " + ind_names_list[m], ) elif plot_type == "levels": plt.plot( @@ -242,10 +252,7 @@ def plot_industry_aggregates( base_tpi[v][ start_index : start_index + num_years_to_plot, m ], - label="Baseline " - + VAR_LABELS[v] - + "for industry " - + str(m), + label="Baseline " + var_label + " " + ind_names_list[m], ) if reform_tpi: plt.plot( @@ -253,10 +260,7 @@ def plot_industry_aggregates( reform_tpi[v][ start_index : start_index + num_years_to_plot, m ], - label="Reform " - + VAR_LABELS[v] - + "for industry " - + str(m), + label="Reform " + var_label + " " + ind_names_list[m], ) ylabel = r"Model Units" elif plot_type == "forecast": @@ -267,10 +271,7 @@ def plot_industry_aggregates( plt.plot( year_vec, plot_var_base, - label="Baseline " - + VAR_LABELS[v] - + "for industry " - + str(m), + label="Baseline " + var_label + " " + ind_names_list[m], ) # Plot change from baseline forecast pct_change = ( @@ -287,7 +288,7 @@ def plot_industry_aggregates( plt.plot( year_vec, plot_var_reform, - label="Reform " + VAR_LABELS[v] + "for industry " + str(m), + label="Reform " + var_label + " " + ind_names_list[m], ) # making units labels will not work if multiple variables # and they are in different units @@ -316,7 +317,7 @@ def plot_industry_aggregates( plt.legend(loc=9, bbox_to_anchor=(0.5, -0.15), ncol=2) if path: fig_path1 = os.path.join(path) - plt.savefig(fig_path1, bbox_inches="tight") + plt.savefig(fig_path1, bbox_inches="tight", dpi=300) else: return fig1 plt.close() @@ -380,7 +381,7 @@ def ss_3Dplot( if plot_title: plt.title(plot_title) if path: - plt.savefig(path) + plt.savefig(path, dpi=300) else: return plt @@ -502,7 +503,7 @@ def plot_gdp_ratio( plt.legend(loc=9, bbox_to_anchor=(0.5, -0.15), ncol=2) if path: fig_path1 = os.path.join(path) - plt.savefig(fig_path1, bbox_inches="tight") + plt.savefig(fig_path1, bbox_inches="tight", dpi=300) else: return fig1 plt.close() @@ -576,7 +577,7 @@ def ability_bar( plt.title(plot_title, fontsize=15) if path: fig_path1 = os.path.join(path) - plt.savefig(fig_path1, bbox_inches="tight") + plt.savefig(fig_path1, bbox_inches="tight", dpi=300) else: return fig plt.close() @@ -630,7 +631,7 @@ def ability_bar_ss( plt.legend(loc=9, bbox_to_anchor=(0.5, -0.15), ncol=2) if path: fig_path1 = os.path.join(path) - plt.savefig(fig_path1, bbox_inches="tight") + plt.savefig(fig_path1, bbox_inches="tight", dpi=300) else: return fig plt.close() @@ -717,7 +718,7 @@ def tpi_profiles( plt.title(plot_title, fontsize=15) if path: fig_path1 = os.path.join(path) - plt.savefig(fig_path1, bbox_inches="tight") + plt.savefig(fig_path1, bbox_inches="tight", dpi=300) else: return fig1 plt.close() @@ -797,7 +798,7 @@ def ss_profiles( plt.title(plot_title, fontsize=15) if path: fig_path1 = os.path.join(path) - plt.savefig(fig_path1, bbox_inches="tight") + plt.savefig(fig_path1, bbox_inches="tight", dpi=300) else: return fig1 plt.close() @@ -1166,7 +1167,7 @@ def inequality_plot( plt.legend(loc=9, bbox_to_anchor=(0.5, -0.15), ncol=2) if path: fig_path1 = os.path.join(path) - plt.savefig(fig_path1, bbox_inches="tight") + plt.savefig(fig_path1, bbox_inches="tight", dpi=300) else: return fig1 plt.close() diff --git a/ogcore/parameter_plots.py b/ogcore/parameter_plots.py index 86e61d7e6..8b43eb12d 100644 --- a/ogcore/parameter_plots.py +++ b/ogcore/parameter_plots.py @@ -37,7 +37,7 @@ def plot_imm_rates(p, year=DEFAULT_START_YEAR, include_title=False, path=None): return fig else: fig_path = os.path.join(path, "imm_rates_orig") - plt.savefig(fig_path) + plt.savefig(fig_path, dpi=300) def plot_mort_rates(p, include_title=False, path=None): @@ -66,7 +66,7 @@ def plot_mort_rates(p, include_title=False, path=None): return fig else: fig_path = os.path.join(path, "mortality_rates") - plt.savefig(fig_path) + plt.savefig(fig_path, dpi=300) def plot_pop_growth( @@ -106,7 +106,7 @@ def plot_pop_growth( return fig else: fig_path = os.path.join(path, "pop_growth_rates") - plt.savefig(fig_path) + plt.savefig(fig_path, dpi=300) def plot_population(p, years_to_plot=["SS"], include_title=False, path=None): @@ -145,7 +145,7 @@ def plot_population(p, years_to_plot=["SS"], include_title=False, path=None): return fig else: fig_path = os.path.join(path, "pop_distribution") - plt.savefig(fig_path) + plt.savefig(fig_path, dpi=300) def plot_ability_profiles(p, include_title=False, path=None): @@ -233,7 +233,7 @@ def plot_elliptical_u(p, plot_MU=True, include_title=False, path=None): return fig else: fig_path = os.path.join(path, "ellipse_v_CFE") - plt.savefig(fig_path) + plt.savefig(fig_path, dpi=300) def plot_chi_n(p, include_title=False, path=None): @@ -259,7 +259,7 @@ def plot_chi_n(p, include_title=False, path=None): return fig else: fig_path = os.path.join(path, "chi_n_values") - plt.savefig(fig_path) + plt.savefig(fig_path, dpi=300) def plot_fert_rates( @@ -332,7 +332,7 @@ def plot_fert_rates( # Save or return figure if output_dir: output_path = os.path.join(output_dir, "fert_rates") - plt.savefig(output_path) + plt.savefig(output_path, dpi=300) plt.close() else: return fig @@ -394,7 +394,7 @@ def plot_mort_rates_data( np.hstack([infmort_rate, mort_rates_all[min_yr - 1 : max_yr]]), ) plt.axvline(x=max_yr, color="red", linestyle="-", linewidth=1) - plt.grid(b=True, which="major", color="0.65", linestyle="-") + plt.grid(visible=True, which="major", color="0.65", linestyle="-") # plt.title('Fitted mortality rate function by age ($rho_{s}$)', # fontsize=20) plt.xlabel(r"Age $s$") @@ -411,7 +411,7 @@ def plot_mort_rates_data( # Save or return figure if output_dir: output_path = os.path.join(output_dir, "mort_rates") - plt.savefig(output_path) + plt.savefig(output_path, dpi=300) plt.close() else: return fig @@ -455,7 +455,7 @@ def plot_omega_fixed( # Save or return figure if output_dir: output_path = os.path.join(output_dir, "OrigVsFixSSpop") - plt.savefig(output_path) + plt.savefig(output_path, dpi=300) plt.close() else: return fig @@ -496,7 +496,7 @@ def plot_imm_fixed( # Save or return figure if output_dir: output_path = os.path.join(output_dir, "OrigVsAdjImm") - plt.savefig(output_path) + plt.savefig(output_path, dpi=300) plt.close() else: return fig @@ -562,7 +562,7 @@ def plot_population_path( # Save or return figure if output_dir: output_path = os.path.join(output_dir, "PopDistPath") - plt.savefig(output_path) + plt.savefig(output_path, dpi=300) plt.close() else: return fig @@ -853,7 +853,7 @@ def txfunc_sse_plot(age_vec, sse_mat, start_year, varstr, output_dir, round): plt.ylabel(r"SSE") graphname = "SSE_" + varstr + "_Round" + str(round) output_path = os.path.join(output_dir, graphname) - plt.savefig(output_path, bbox_inches="tight") + plt.savefig(output_path, bbox_inches="tight", dpi=300) plt.close() @@ -889,7 +889,7 @@ def plot_income_data( plt.plot(ages, emat) filename = "ability_2D_lev" + filesuffix fullpath = os.path.join(output_dir, filename) - plt.savefig(fullpath) + plt.savefig(fullpath, dpi=300) plt.close() # Plot of 2D, J=1 in logs @@ -897,7 +897,7 @@ def plot_income_data( plt.plot(ages, np.log(emat)) filename = "ability_2D_log" + filesuffix fullpath = os.path.join(output_dir, filename) - plt.savefig(fullpath) + plt.savefig(fullpath, dpi=300) plt.close() else: # Plot of 3D, J>1 in levels @@ -910,7 +910,7 @@ def plot_income_data( ax10.set_zlabel(r"ability $e_{j,s}$") filename = "ability_3D_lev" + filesuffix fullpath = os.path.join(output_dir, filename) - plt.savefig(fullpath) + plt.savefig(fullpath, dpi=300) plt.close() # Plot of 3D, J>1 in logs @@ -928,7 +928,7 @@ def plot_income_data( ax11.set_zlabel(r"log ability $log(e_{j,s})$") filename = "ability_3D_log" + filesuffix fullpath = os.path.join(output_dir, filename) - plt.savefig(fullpath) + plt.savefig(fullpath, dpi=300) plt.close() if J <= 10: # Restricted because of line and marker types @@ -976,7 +976,7 @@ def plot_income_data( ax.set_ylabel(r"log ability $log(e_{j,s})$") filename = "ability_2D_log" + filesuffix fullpath = os.path.join(output_dir, filename) - plt.savefig(fullpath) + plt.savefig(fullpath, dpi=300) plt.close() else: if J <= 10: # Restricted because of line and marker types @@ -1189,4 +1189,4 @@ def wm(x): if path is None: return fig else: - plt.savefig(path) + plt.savefig(path, dpi=300)