From 527c074447bbbcbb8b6fb620f69ee0a94c50d5bd Mon Sep 17 00:00:00 2001 From: EhsanGharibNezhad Date: Mon, 6 Nov 2023 17:32:21 -0800 Subject: [PATCH] minor update of StatVisAnalyzer.py --- TelescopeML/StatVisAnalyzer.py | 61 +++++++++++++++++++++++++++++++++- 1 file changed, 60 insertions(+), 1 deletion(-) diff --git a/TelescopeML/StatVisAnalyzer.py b/TelescopeML/StatVisAnalyzer.py index 953389b8..378ea8dd 100644 --- a/TelescopeML/StatVisAnalyzer.py +++ b/TelescopeML/StatVisAnalyzer.py @@ -14,6 +14,7 @@ from sklearn.metrics import r2_score, mean_squared_error from scipy.interpolate import RegularGridInterpolator from scipy.stats import chi2 +import os import pprint @@ -1340,4 +1341,62 @@ def plot_scatter_x_y (x, y, p.grid.grid_line_alpha = 0.5 # Show the plot - show(p) \ No newline at end of file + show(p) + + +def plot_filtered_dataframe(dataset, filter_bounds, feature_to_plot, title_label, wl_synthetic, __reference_data__): + """ + Plot a DataFrame with a single x-axis (using column names) and multiple y-axes. + + Parameters: + - df (pd.DataFrame): DataFrame containing the data to be plotted. + """ + + filtered_df = dataset.copy() + for feature, bounds in filter_bounds.items(): + lower_bound, upper_bound = bounds + filtered_df = filtered_df[(filtered_df[feature] >= lower_bound) & (filtered_df[feature] <= upper_bound)] + + filtered_df2 = filtered_df.sort_values(feature_to_plot, ascending=False).iloc[::1, 4:-1][::-1] + + fig, ax = plt.subplots(figsize=(10, 3)) + + x = filtered_df2.columns + df_transposed = filtered_df2.T # Transpose the DataFrame + + # Define a color palette + num_colors = len(df_transposed.columns) # Number of colors needed (excluding x-axis) + colors = sns.color_palette('magma', num_colors) + + for i, col in enumerate(df_transposed.columns): + # print(col) + if col != 'x': # Skip the x-axis column + ax.semilogy(wl_synthetic, df_transposed[col], + # label=data[col][:4].values, + color=colors[i], alpha=0.7) + + # print(filtered_data.T[col][:4].values[0]) + ax.set_xlabel('Features (Wavelength [$\mu$m])') + ax.set_ylabel(r'F$_{\nu}$ [erg/cm$^2$/s/Hz]') + dict_features = {'temperature': 'Effective Temperature', 'gravity': 'Gravity', 'metallicity': 'Metallicity', + 'c_o_ratio': 'Carbon-to-oxygen ratio'} + ax.set_title(dict_features[feature_to_plot] + " " + title_label) + # ax.legend() + + # Get the minimum and maximum values from the data + # vmin = df_transposed.values.min() + # vmax = df_transposed.values.max() + + # Add colorbar + cmap = sns.color_palette('magma', as_cmap=True) + cbar = plt.colorbar(plt.cm.ScalarMappable(cmap=cmap, + norm=plt.Normalize(vmin=filter_bounds[feature_to_plot][0], + vmax=filter_bounds[feature_to_plot][1])), ax=ax) + # dict_features2 = {'temperature':'T [K]', 'gravity':'log$g$', 'metallicity':'[M/H]', 'c_o_ratio':'C/O ratio'} + dict_features = {'temperature': 'T$_{eff}$ [K]', 'gravity': 'log$g$', 'metallicity': '[M/H]', 'c_o_ratio': 'C/O'} + cbar.set_label(dict_features[feature_to_plot]) + + plt.savefig(os.path.join(__reference_data__, 'figures', feature_to_plot + "_trainin_examples.pdf"), dpi=500, + bbox_inches='tight') + + plt.show()