From cde3a7aa49d434af4d8be5bcc83490de3aaffc48 Mon Sep 17 00:00:00 2001
From: Dixing Xu
Date: Wed, 27 Nov 2024 22:53:07 +0800
Subject: [PATCH] :recycle: Refactor webui to live render results

---
 aide/webui/app.py | 299 ++++++++++------------------------------------
 1 file changed, 62 insertions(+), 237 deletions(-)

diff --git a/aide/webui/app.py b/aide/webui/app.py
index 68f9d73..9a6f4a4 100644
--- a/aide/webui/app.py
+++ b/aide/webui/app.py
@@ -100,8 +100,6 @@ def run(self):
         input_col, results_col = st.columns([1, 3])
         with input_col:
             self.render_input_section(results_col)
-        with results_col:
-            self.render_results_section()
 
     def render_sidebar(self):
         """
@@ -273,17 +271,46 @@ def run_aide(self, files, goal_text, eval_text, num_steps, results_col):
                 return None
 
             experiment = self.initialize_experiment(input_dir, goal_text, eval_text)
-            placeholders = self.create_results_placeholders(results_col, experiment)
+
+            # Create separate placeholders for progress and config
+            progress_placeholder = results_col.empty()
+            config_placeholder = results_col.empty()
+            results_placeholder = results_col.empty()
 
             for step in range(num_steps):
                 st.session_state.current_step = step + 1
                 progress = (step + 1) / num_steps
-                self.update_results_placeholders(placeholders, progress)
+
+                # Update progress
+                with progress_placeholder.container():
+                    st.markdown(
+                        f"### 🔥 Running Step {st.session_state.current_step}/{st.session_state.total_steps}"
+                    )
+                    st.progress(progress)
+
+                # Show config only for first step
+                if step == 0:
+                    with config_placeholder.container():
+                        st.markdown("### 📋 Configuration")
+                        st.code(OmegaConf.to_yaml(experiment.cfg), language="yaml")
+
                 experiment.run(steps=1)
 
-            self.clear_run_state(placeholders)
+                # Show results
+                with results_placeholder.container():
+                    self.render_live_results(experiment)
 
-            return self.collect_results(experiment)
+                # Clear config after first step
+                if step == 0:
+                    config_placeholder.empty()
+
+            # Clear progress after all steps
+            progress_placeholder.empty()
+
+            # Update session state
+            st.session_state.is_running = False
+            st.session_state.results = self.collect_results(experiment)
+            return st.session_state.results
 
         except Exception as e:
             st.session_state.is_running = False
@@ -355,70 +382,6 @@ def initialize_experiment(input_dir, goal_text, eval_text):
         experiment = Experiment(data_dir=str(input_dir), goal=goal_text, eval=eval_text)
         return experiment
 
-    @staticmethod
-    def create_results_placeholders(results_col, experiment):
-        """
-        Create placeholders in the results column for dynamic content.
-
-        Args:
-            results_col (st.delta_generator.DeltaGenerator): The results column.
-            experiment (Experiment): The Experiment object.
-
-        Returns:
-            dict: Dictionary of placeholders.
-        """
-        with results_col:
-            status_placeholder = st.empty()
-            step_placeholder = st.empty()
-            config_title_placeholder = st.empty()
-            config_placeholder = st.empty()
-            progress_placeholder = st.empty()
-
-            step_placeholder.markdown(
-                f"### 🔥 Running Step {st.session_state.current_step}/{st.session_state.total_steps}"
-            )
-            config_title_placeholder.markdown("### 📋 Configuration")
-            config_placeholder.code(OmegaConf.to_yaml(experiment.cfg), language="yaml")
-            progress_placeholder.progress(0)
-
-            placeholders = {
-                "status": status_placeholder,
-                "step": step_placeholder,
-                "config_title": config_title_placeholder,
-                "config": config_placeholder,
-                "progress": progress_placeholder,
-            }
-        return placeholders
-
-    @staticmethod
-    def update_results_placeholders(placeholders, progress):
-        """
-        Update the placeholders with the current progress.
-
-        Args:
-            placeholders (dict): Dictionary of placeholders.
-            progress (float): Current progress value.
-        """
-        placeholders["step"].markdown(
-            f"### 🔥 Running Step {st.session_state.current_step}/{st.session_state.total_steps}"
-        )
-        placeholders["progress"].progress(progress)
-
-    @staticmethod
-    def clear_run_state(placeholders):
-        """
-        Clear the running state and placeholders after the experiment.
-
-        Args:
-            placeholders (dict): Dictionary of placeholders.
-        """
-        st.session_state.is_running = False
-        placeholders["status"].empty()
-        placeholders["step"].empty()
-        placeholders["config_title"].empty()
-        placeholders["config"].empty()
-        placeholders["progress"].empty()
-
     @staticmethod
     def collect_results(experiment):
         """
@@ -454,177 +417,39 @@ def collect_results(experiment):
         }
         return results
 
-    def render_results_section(self):
+    def render_live_results(self, experiment):
         """
-        Render the results section with tabs for different outputs.
-        """
-        st.header("Results")
-        if st.session_state.get("results"):
-            results = st.session_state.results
-
-            tabs = st.tabs(
-                [
-                    "Tree Visualization",
-                    "Best Solution",
-                    "Config",
-                    "Journal",
-                    "Validation Plot",
-                ]
-            )
-
-            with tabs[0]:
-                self.render_tree_visualization(results)
-            with tabs[1]:
-                self.render_best_solution(results)
-            with tabs[2]:
-                self.render_config(results)
-            with tabs[3]:
-                self.render_journal(results)
-            with tabs[4]:
-                # Display best score before the plot
-                best_metric = self.get_best_metric(results)
-                if best_metric is not None:
-                    st.metric("Best Validation Score", f"{best_metric:.4f}")
-                self.render_validation_plot(results)
-        else:
-            st.info("No results to display. Please run an experiment.")
-
-    @staticmethod
-    def render_tree_visualization(results):
-        """
-        Render the tree visualization from the experiment results.
-
-        Args:
-            results (dict): The results dictionary containing paths and data.
-        """
-        if "tree_path" in results:
-            tree_path = Path(results["tree_path"])
-            logger.info(f"Loading tree visualization from: {tree_path}")
-            if tree_path.exists():
-                with open(tree_path, "r", encoding="utf-8") as f:
-                    html_content = f.read()
-                components.html(html_content, height=600, scrolling=True)
-            else:
-                st.error(f"Tree visualization file not found at: {tree_path}")
-                logger.error(f"Tree file not found at: {tree_path}")
-        else:
-            st.info("No tree visualization available for this run.")
-
-    @staticmethod
-    def render_best_solution(results):
-        """
-        Display the best solution code.
-
-        Args:
-            results (dict): The results dictionary containing the solution.
-        """
-        if "solution" in results:
-            solution_code = results["solution"]
-            st.code(solution_code, language="python")
-        else:
-            st.info("No solution available.")
-
-    @staticmethod
-    def render_config(results):
-        """
-        Display the configuration used in the experiment.
+        Render live results.
 
         Args:
-            results (dict): The results dictionary containing the config.
-        """
-        if "config" in results:
-            st.code(results["config"], language="yaml")
-        else:
-            st.info("No configuration available.")
-
-    @staticmethod
-    def render_journal(results):
-        """
-        Display the experiment journal as JSON.
-
-        Args:
-            results (dict): The results dictionary containing the journal.
-        """
-        if "journal" in results:
-            try:
-                journal_data = json.loads(results["journal"])
-                formatted_journal = json.dumps(journal_data, indent=2)
-                st.code(formatted_journal, language="json")
-            except json.JSONDecodeError:
-                st.code(results["journal"], language="json")
-        else:
-            st.info("No journal available.")
-
-    @staticmethod
-    def get_best_metric(results):
-        """
-        Extract the best validation metric from results.
-        """
-        try:
-            journal_data = json.loads(results["journal"])
-            metrics = []
-            for node in journal_data:
-                if node["metric"] is not None:
-                    try:
-                        # Convert string metric to float
-                        metric_value = float(node["metric"])
-                        metrics.append(metric_value)
-                    except (ValueError, TypeError):
-                        continue
-            return max(metrics) if metrics else None
-        except (json.JSONDecodeError, KeyError):
-            return None
-
-    @staticmethod
-    def render_validation_plot(results):
-        """
-        Render the validation score plot.
-        """
-        try:
-            journal_data = json.loads(results["journal"])
-            steps = []
-            metrics = []
-
-            for node in journal_data:
-                if node["metric"] is not None and node["metric"].lower() != "none":
-                    try:
-                        metric_value = float(node["metric"])
-                        steps.append(node["step"])
-                        metrics.append(metric_value)
-                    except (ValueError, TypeError):
-                        continue
-
-            if metrics:
-                import plotly.graph_objects as go
-
-                fig = go.Figure()
-                fig.add_trace(
-                    go.Scatter(
-                        x=steps,
-                        y=metrics,
-                        mode="lines+markers",
-                        name="Validation Score",
-                        line=dict(color="#F04370"),
-                        marker=dict(color="#F04370"),
-                    )
-                )
-
-                fig.update_layout(
-                    title="Validation Score Progress",
-                    xaxis_title="Step",
-                    yaxis_title="Validation Score",
-                    template="plotly_white",
-                    hovermode="x unified",
-                    plot_bgcolor="rgba(0,0,0,0)",
-                    paper_bgcolor="rgba(0,0,0,0)",
-                )
-
-                st.plotly_chart(fig, use_container_width=True)
-            else:
-                st.info("No validation metrics available to plot.")
+            experiment (Experiment): The Experiment object
+        """
+        results = self.collect_results(experiment)
+
+        # Create tabs for different result views
+        tabs = st.tabs(
+            [
+                "Tree Visualization",
+                "Best Solution",
+                "Config",
+                "Journal",
+                "Validation Plot",
+            ]
+        )
 
-        except (json.JSONDecodeError, KeyError):
-            st.error("Could not parse validation metrics data.")
+        with tabs[0]:
+            self.render_tree_visualization(results)
+        with tabs[1]:
+            self.render_best_solution(results)
+        with tabs[2]:
+            self.render_config(results)
+        with tabs[3]:
+            self.render_journal(results)
+        with tabs[4]:
+            best_metric = self.get_best_metric(results)
+            if best_metric is not None:
+                st.metric("Best Validation Score", f"{best_metric:.4f}")
+            self.render_validation_plot(results, step=st.session_state.current_step)
 
 
 if __name__ == "__main__":