Skip to content

Commit

Permalink
minor update of StatVisAnalyzer.py
Browse files Browse the repository at this point in the history
  • Loading branch information
EhsanGharibNezhad committed Nov 7, 2023
1 parent 42cd712 commit 527c074
Showing 1 changed file with 60 additions and 1 deletion.
61 changes: 60 additions & 1 deletion TelescopeML/StatVisAnalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from sklearn.metrics import r2_score, mean_squared_error
from scipy.interpolate import RegularGridInterpolator
from scipy.stats import chi2
import os

import pprint

Expand Down Expand Up @@ -1340,4 +1341,62 @@ def plot_scatter_x_y (x, y,
p.grid.grid_line_alpha = 0.5

# Show the plot
show(p)
show(p)


def plot_filtered_dataframe(dataset, filter_bounds, feature_to_plot, title_label, wl_synthetic, __reference_data__):
"""
Plot a DataFrame with a single x-axis (using column names) and multiple y-axes.
Parameters:
- df (pd.DataFrame): DataFrame containing the data to be plotted.
"""

filtered_df = dataset.copy()
for feature, bounds in filter_bounds.items():
lower_bound, upper_bound = bounds
filtered_df = filtered_df[(filtered_df[feature] >= lower_bound) & (filtered_df[feature] <= upper_bound)]

filtered_df2 = filtered_df.sort_values(feature_to_plot, ascending=False).iloc[::1, 4:-1][::-1]

fig, ax = plt.subplots(figsize=(10, 3))

x = filtered_df2.columns
df_transposed = filtered_df2.T # Transpose the DataFrame

# Define a color palette
num_colors = len(df_transposed.columns) # Number of colors needed (excluding x-axis)
colors = sns.color_palette('magma', num_colors)

for i, col in enumerate(df_transposed.columns):
# print(col)
if col != 'x': # Skip the x-axis column
ax.semilogy(wl_synthetic, df_transposed[col],
# label=data[col][:4].values,
color=colors[i], alpha=0.7)

# print(filtered_data.T[col][:4].values[0])
ax.set_xlabel('Features (Wavelength [$\mu$m])')
ax.set_ylabel(r'F$_{\nu}$ [erg/cm$^2$/s/Hz]')
dict_features = {'temperature': 'Effective Temperature', 'gravity': 'Gravity', 'metallicity': 'Metallicity',
'c_o_ratio': 'Carbon-to-oxygen ratio'}
ax.set_title(dict_features[feature_to_plot] + " " + title_label)
# ax.legend()

# Get the minimum and maximum values from the data
# vmin = df_transposed.values.min()
# vmax = df_transposed.values.max()

# Add colorbar
cmap = sns.color_palette('magma', as_cmap=True)
cbar = plt.colorbar(plt.cm.ScalarMappable(cmap=cmap,
norm=plt.Normalize(vmin=filter_bounds[feature_to_plot][0],
vmax=filter_bounds[feature_to_plot][1])), ax=ax)
# dict_features2 = {'temperature':'T [K]', 'gravity':'log$g$', 'metallicity':'[M/H]', 'c_o_ratio':'C/O ratio'}
dict_features = {'temperature': 'T$_{eff}$ [K]', 'gravity': 'log$g$', 'metallicity': '[M/H]', 'c_o_ratio': 'C/O'}
cbar.set_label(dict_features[feature_to_plot])

plt.savefig(os.path.join(__reference_data__, 'figures', feature_to_plot + "_trainin_examples.pdf"), dpi=500,
bbox_inches='tight')

plt.show()

0 comments on commit 527c074

Please sign in to comment.