Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Add validation plot and score for webui #28

Merged
merged 3 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 112 additions & 15 deletions aide/webui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,24 +158,35 @@ def handle_file_upload(self):
Returns:
list: List of uploaded or example files.
"""
if st.button(
"Load Example Experiment", type="primary", use_container_width=True
):
st.session_state.example_files = self.load_example_files()
# Only show file uploader if no example files are loaded
if not st.session_state.get("example_files"):
uploaded_files = st.file_uploader(
"Upload Data Files",
accept_multiple_files=True,
type=["csv", "txt", "json", "md"],
label_visibility="collapsed",
)

if uploaded_files:
st.session_state.pop(
"example_files", None
) # Remove example files if any
return uploaded_files

# Only show example button if no files are uploaded
if st.button(
"Load Example Experiment", type="primary", use_container_width=True
):
st.session_state.example_files = self.load_example_files()

if st.session_state.get("example_files"):
st.info("Example files loaded! Click 'Run AIDE' to proceed.")
with st.expander("View Loaded Files", expanded=False):
for file in st.session_state.example_files:
st.text(f"📄 {file['name']}")
uploaded_files = st.session_state.example_files
else:
uploaded_files = st.file_uploader(
"Upload Data Files",
accept_multiple_files=True,
type=["csv", "txt", "json", "md"],
)
return uploaded_files
return st.session_state.example_files

return [] # Return empty list if no files are uploaded or loaded

def handle_user_inputs(self):
"""
Expand All @@ -187,12 +198,12 @@ def handle_user_inputs(self):
goal_text = st.text_area(
"Goal",
value=st.session_state.get("goal", ""),
placeholder="Example: Predict house prices",
placeholder="Example: Predict the sales price for each house",
)
eval_text = st.text_area(
"Evaluation Criteria",
value=st.session_state.get("eval", ""),
placeholder="Example: Use RMSE metric",
placeholder="Example: Use the RMSE metric between the logarithm of the predicted and observed values.",
)
num_steps = st.slider(
"Number of Steps",
Expand Down Expand Up @@ -450,7 +461,16 @@ def render_results_section(self):
st.header("Results")
if st.session_state.get("results"):
results = st.session_state.results
tabs = st.tabs(["Tree Visualization", "Best Solution", "Config", "Journal"])

tabs = st.tabs(
[
"Tree Visualization",
"Best Solution",
"Config",
"Journal",
"Validation Plot",
]
)

with tabs[0]:
self.render_tree_visualization(results)
Expand All @@ -460,6 +480,12 @@ def render_results_section(self):
self.render_config(results)
with tabs[3]:
self.render_journal(results)
with tabs[4]:
# Display best score before the plot
best_metric = self.get_best_metric(results)
if best_metric is not None:
st.metric("Best Validation Score", f"{best_metric:.4f}")
self.render_validation_plot(results)
else:
st.info("No results to display. Please run an experiment.")

Expand Down Expand Up @@ -529,6 +555,77 @@ def render_journal(results):
else:
st.info("No journal available.")

@staticmethod
def get_best_metric(results, maximize=True):
    """
    Extract the best validation metric from results.

    Args:
        results (dict): Experiment results. ``results["journal"]`` is read
            as a JSON string encoding a list of node objects, each of which
            may carry a ``"metric"`` entry (a number or a numeric string).
        maximize (bool): When True (the default, matching the previous
            behavior) the largest metric is considered best. Pass False for
            lower-is-better metrics such as RMSE.

    Returns:
        float | None: The best finite metric value, or None when the journal
        is missing, is not valid JSON, is not a list, or holds no usable
        metric.
    """
    import math

    try:
        journal_data = json.loads(results["journal"])
    except (json.JSONDecodeError, KeyError, TypeError):
        # TypeError covers a null/non-string journal and a non-dict `results`.
        return None

    if not isinstance(journal_data, list):
        return None

    metrics = []
    for node in journal_data:
        # A malformed node should not discard the valid scores around it.
        if not isinstance(node, dict):
            continue
        raw_metric = node.get("metric")
        if raw_metric is None:
            continue
        try:
            # Convert string metric to float
            metric_value = float(raw_metric)
        except (ValueError, TypeError):
            continue
        # float() accepts "nan"/"inf"; NaN makes max()/min() order-dependent
        # and an infinite score is not a meaningful "best" value.
        if math.isfinite(metric_value):
            metrics.append(metric_value)

    if not metrics:
        return None
    return max(metrics) if maximize else min(metrics)

@staticmethod
def render_validation_plot(results):
    """
    Render the validation score plot.

    Reads ``results["journal"]`` as a JSON string, collects the ``"step"``
    and ``"metric"`` of every node whose metric converts to a float, and
    draws them as a line chart with ``st.plotly_chart``. Shows an info
    message when no metric is usable and an error message when the journal
    cannot be parsed or a node lacks the expected keys.

    Args:
        results (dict): Experiment results containing a ``"journal"`` entry.
    """
    steps = []
    metrics = []

    # Only the parsing stage is guarded, so that an exception raised while
    # plotting is not misreported as a data-parsing problem.
    try:
        journal_data = json.loads(results["journal"])

        for node in journal_data:
            raw_metric = node["metric"]
            if raw_metric is None:
                continue
            try:
                # float() handles numeric and string metrics alike; the
                # literal "none" (any casing) raises ValueError and is
                # skipped, so no str-only `.lower()` check is needed.
                metric_value = float(raw_metric)
            except (ValueError, TypeError):
                continue
            steps.append(node["step"])
            metrics.append(metric_value)
    except (json.JSONDecodeError, KeyError, TypeError):
        # TypeError covers a null/non-string journal and non-dict nodes.
        st.error("Could not parse validation metrics data.")
        return

    if not metrics:
        st.info("No validation metrics available to plot.")
        return

    import plotly.graph_objects as go

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=steps,
            y=metrics,
            mode="lines+markers",
            name="Validation Score",
            line=dict(color="#F04370"),
            marker=dict(color="#F04370"),
        )
    )

    fig.update_layout(
        title="Validation Score Progress",
        xaxis_title="Step",
        yaxis_title="Validation Score",
        template="plotly_white",
        hovermode="x unified",
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
    )

    st.plotly_chart(fig, use_container_width=True)


if __name__ == "__main__":
app = WebUI()
Expand Down
2 changes: 1 addition & 1 deletion aide/webui/style.css
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/* Main colors */
:root {
--background: #F2F0E7;
--background-shaded: #EBE8DD;
--background-shaded: #FFFFFF;
--card: #FFFFFF;
--primary: #0D0F18;
--accent: #F04370;
Expand Down