♻️ Refactor webui to live render results
dexhunter committed Nov 27, 2024
1 parent f3092ac commit cde3a7a
Showing 1 changed file with 62 additions and 237 deletions.
299 changes: 62 additions & 237 deletions aide/webui/app.py
@@ -100,8 +100,6 @@ def run(self):
input_col, results_col = st.columns([1, 3])
with input_col:
self.render_input_section(results_col)
with results_col:
self.render_results_section()

def render_sidebar(self):
"""
@@ -273,17 +271,46 @@ def run_aide(self, files, goal_text, eval_text, num_steps, results_col):
return None

experiment = self.initialize_experiment(input_dir, goal_text, eval_text)
placeholders = self.create_results_placeholders(results_col, experiment)

# Create separate placeholders for progress and config
progress_placeholder = results_col.empty()
config_placeholder = results_col.empty()
results_placeholder = results_col.empty()

for step in range(num_steps):
st.session_state.current_step = step + 1
progress = (step + 1) / num_steps
self.update_results_placeholders(placeholders, progress)

# Update progress
with progress_placeholder.container():
st.markdown(
f"### 🔥 Running Step {st.session_state.current_step}/{st.session_state.total_steps}"
)
st.progress(progress)

# Show config only for first step
if step == 0:
with config_placeholder.container():
st.markdown("### 📋 Configuration")
st.code(OmegaConf.to_yaml(experiment.cfg), language="yaml")

experiment.run(steps=1)

self.clear_run_state(placeholders)
# Show results
with results_placeholder.container():
self.render_live_results(experiment)

return self.collect_results(experiment)
# Clear config after first step
if step == 0:
config_placeholder.empty()

# Clear progress after all steps
progress_placeholder.empty()

# Update session state
st.session_state.is_running = False
st.session_state.results = self.collect_results(experiment)
return st.session_state.results

except Exception as e:
st.session_state.is_running = False
@@ -355,70 +382,6 @@ def initialize_experiment(input_dir, goal_text, eval_text):
experiment = Experiment(data_dir=str(input_dir), goal=goal_text, eval=eval_text)
return experiment

@staticmethod
def create_results_placeholders(results_col, experiment):
"""
Create placeholders in the results column for dynamic content.
Args:
results_col (st.delta_generator.DeltaGenerator): The results column.
experiment (Experiment): The Experiment object.
Returns:
dict: Dictionary of placeholders.
"""
with results_col:
status_placeholder = st.empty()
step_placeholder = st.empty()
config_title_placeholder = st.empty()
config_placeholder = st.empty()
progress_placeholder = st.empty()

step_placeholder.markdown(
f"### 🔥 Running Step {st.session_state.current_step}/{st.session_state.total_steps}"
)
config_title_placeholder.markdown("### 📋 Configuration")
config_placeholder.code(OmegaConf.to_yaml(experiment.cfg), language="yaml")
progress_placeholder.progress(0)

placeholders = {
"status": status_placeholder,
"step": step_placeholder,
"config_title": config_title_placeholder,
"config": config_placeholder,
"progress": progress_placeholder,
}
return placeholders

@staticmethod
def update_results_placeholders(placeholders, progress):
"""
Update the placeholders with the current progress.
Args:
placeholders (dict): Dictionary of placeholders.
progress (float): Current progress value.
"""
placeholders["step"].markdown(
f"### 🔥 Running Step {st.session_state.current_step}/{st.session_state.total_steps}"
)
placeholders["progress"].progress(progress)

@staticmethod
def clear_run_state(placeholders):
"""
Clear the running state and placeholders after the experiment.
Args:
placeholders (dict): Dictionary of placeholders.
"""
st.session_state.is_running = False
placeholders["status"].empty()
placeholders["step"].empty()
placeholders["config_title"].empty()
placeholders["config"].empty()
placeholders["progress"].empty()

@staticmethod
def collect_results(experiment):
"""
@@ -454,177 +417,39 @@ def collect_results(experiment):
}
return results

def render_results_section(self):
def render_live_results(self, experiment):
"""
Render the results section with tabs for different outputs.
"""
st.header("Results")
if st.session_state.get("results"):
results = st.session_state.results

tabs = st.tabs(
[
"Tree Visualization",
"Best Solution",
"Config",
"Journal",
"Validation Plot",
]
)

with tabs[0]:
self.render_tree_visualization(results)
with tabs[1]:
self.render_best_solution(results)
with tabs[2]:
self.render_config(results)
with tabs[3]:
self.render_journal(results)
with tabs[4]:
# Display best score before the plot
best_metric = self.get_best_metric(results)
if best_metric is not None:
st.metric("Best Validation Score", f"{best_metric:.4f}")
self.render_validation_plot(results)
else:
st.info("No results to display. Please run an experiment.")

@staticmethod
def render_tree_visualization(results):
"""
Render the tree visualization from the experiment results.
Args:
results (dict): The results dictionary containing paths and data.
"""
if "tree_path" in results:
tree_path = Path(results["tree_path"])
logger.info(f"Loading tree visualization from: {tree_path}")
if tree_path.exists():
with open(tree_path, "r", encoding="utf-8") as f:
html_content = f.read()
components.html(html_content, height=600, scrolling=True)
else:
st.error(f"Tree visualization file not found at: {tree_path}")
logger.error(f"Tree file not found at: {tree_path}")
else:
st.info("No tree visualization available for this run.")

@staticmethod
def render_best_solution(results):
"""
Display the best solution code.
Args:
results (dict): The results dictionary containing the solution.
"""
if "solution" in results:
solution_code = results["solution"]
st.code(solution_code, language="python")
else:
st.info("No solution available.")

@staticmethod
def render_config(results):
"""
Display the configuration used in the experiment.
Render live results.
Args:
results (dict): The results dictionary containing the config.
"""
if "config" in results:
st.code(results["config"], language="yaml")
else:
st.info("No configuration available.")

@staticmethod
def render_journal(results):
"""
Display the experiment journal as JSON.
Args:
results (dict): The results dictionary containing the journal.
"""
if "journal" in results:
try:
journal_data = json.loads(results["journal"])
formatted_journal = json.dumps(journal_data, indent=2)
st.code(formatted_journal, language="json")
except json.JSONDecodeError:
st.code(results["journal"], language="json")
else:
st.info("No journal available.")

@staticmethod
def get_best_metric(results):
"""
Extract the best validation metric from results.
"""
try:
journal_data = json.loads(results["journal"])
metrics = []
for node in journal_data:
if node["metric"] is not None:
try:
# Convert string metric to float
metric_value = float(node["metric"])
metrics.append(metric_value)
except (ValueError, TypeError):
continue
return max(metrics) if metrics else None
except (json.JSONDecodeError, KeyError):
return None

@staticmethod
def render_validation_plot(results):
"""
Render the validation score plot.
"""
try:
journal_data = json.loads(results["journal"])
steps = []
metrics = []

for node in journal_data:
if node["metric"] is not None and node["metric"].lower() != "none":
try:
metric_value = float(node["metric"])
steps.append(node["step"])
metrics.append(metric_value)
except (ValueError, TypeError):
continue

if metrics:
import plotly.graph_objects as go

fig = go.Figure()
fig.add_trace(
go.Scatter(
x=steps,
y=metrics,
mode="lines+markers",
name="Validation Score",
line=dict(color="#F04370"),
marker=dict(color="#F04370"),
)
)

fig.update_layout(
title="Validation Score Progress",
xaxis_title="Step",
yaxis_title="Validation Score",
template="plotly_white",
hovermode="x unified",
plot_bgcolor="rgba(0,0,0,0)",
paper_bgcolor="rgba(0,0,0,0)",
)

st.plotly_chart(fig, use_container_width=True)
else:
st.info("No validation metrics available to plot.")
experiment (Experiment): The Experiment object
"""
results = self.collect_results(experiment)

# Create tabs for different result views
tabs = st.tabs(
[
"Tree Visualization",
"Best Solution",
"Config",
"Journal",
"Validation Plot",
]
)

except (json.JSONDecodeError, KeyError):
st.error("Could not parse validation metrics data.")
with tabs[0]:
self.render_tree_visualization(results)
with tabs[1]:
self.render_best_solution(results)
with tabs[2]:
self.render_config(results)
with tabs[3]:
self.render_journal(results)
with tabs[4]:
best_metric = self.get_best_metric(results)
if best_metric is not None:
st.metric("Best Validation Score", f"{best_metric:.4f}")
self.render_validation_plot(results, step=st.session_state.current_step)


if __name__ == "__main__":
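For readers skimming the diff, the core change is that run_aide now drives the results column itself: it creates st.empty() placeholders, redraws a progress header before each step, runs the experiment one step at a time, and re-renders the result tabs into a placeholder after every step, instead of waiting for a final render_results_section pass. A minimal, self-contained sketch of that placeholder-based live-render pattern (not the project's actual code; the step count, sleep, and metric values below are hypothetical stand-ins):

import time
import streamlit as st

NUM_STEPS = 5  # hypothetical stand-in for the user-selected step count

progress_placeholder = st.empty()   # redrawn before each step, cleared at the end
results_placeholder = st.empty()    # re-rendered with fresh results after each step

scores = []
for step in range(NUM_STEPS):
    with progress_placeholder.container():
        st.markdown(f"### Running Step {step + 1}/{NUM_STEPS}")
        st.progress((step + 1) / NUM_STEPS)

    time.sleep(0.5)                   # stand-in for experiment.run(steps=1)
    scores.append((step + 1) * 0.1)   # stand-in for collecting a per-step metric

    with results_placeholder.container():
        st.metric("Latest validation score", f"{scores[-1]:.2f}")
        st.line_chart(scores)

progress_placeholder.empty()  # drop the progress UI once all steps have finished

The design choice mirrored from the commit is one placeholder per screen region, so each loop iteration overwrites the previous step's output in place rather than appending new elements below it.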
