Skip to content

Commit

Permalink
Merge e14da5f into ab20341
Browse files Browse the repository at this point in the history
  • Loading branch information
gtdang authored Dec 21, 2023
2 parents ab20341 + e14da5f commit 2d4bd76
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 55 deletions.
79 changes: 56 additions & 23 deletions hnn_core/gui/_viz_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import copy
import io
from functools import partial
from functools import partial, wraps

import matplotlib
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -99,6 +99,35 @@ def plot_type_coupled_change(new_plot_type, target_data_selection):
target_data_selection.disabled = False


def unlink_relink(attribute: str):
"""
Wrapper function to unlink widgets to perform edits and re-link them on
completion. Used as a decorator on class methods. The class must have an
attribute containing an ipywidgets/traitlets link object.
Parameters
----------
attribute: The class attribute containing link object of ipywidgets widgets
"""
def _unlink_relink(f):
@wraps(f)
def wrapper(self, *args, **kwargs):
# Unlink the widgets using the provided link object
link_attribute: link = getattr(self, attribute)
link_attribute.unlink()

# Call the original function
result = f(self, *args, **kwargs)

# Re-link the widgets using link.link()
link_attribute.link()

return result
return wrapper
return _unlink_relink


def _idx2figname(idx):
return f"Figure {idx}"

Expand Down Expand Up @@ -226,9 +255,7 @@ def _update_ax(fig, ax, single_simulation, sim_name, plot_type, plot_config):
def _static_rerender(widgets, fig, fig_idx):
logger.debug('_static_re_render is called')
figs_tabs = widgets['figs_tabs']
titles = [
figs_tabs.get_title(idx) for idx in range(len(figs_tabs.children))
]
titles = figs_tabs.titles
fig_tab_idx = titles.index(_idx2figname(fig_idx))
fig_output = widgets['figs_tabs'].children[fig_tab_idx]
fig_output.clear_output()
Expand Down Expand Up @@ -501,18 +528,25 @@ def _on_plot_type_change(new_plot_type):
def _close_figure(b, widgets, data, fig_idx):
fig_related_widgets = [widgets['figs_tabs'], widgets['axes_config_tabs']]
for w_idx, tab in enumerate(fig_related_widgets):
# Get tab object's list of children and their titles
tab_children = list(tab.children)
titles = [tab.get_title(idx) for idx in range(len(tab.children))]
titles = list(tab.titles)
# Get the index based on the title
tab_idx = titles.index(_idx2figname(fig_idx))
# Remove the child and title specified
print(f"Del fig_idx={fig_idx}, fig_idx={fig_idx}")
del tab_children[tab_idx], titles[tab_idx]

tab.children = tuple(tab_children)
[tab.set_title(idx, title) for idx, title in enumerate(titles)]
tab_children.pop(tab_idx)
titles.pop(tab_idx)
# Reset children and titles of the tab object
tab.children = tab_children
tab.titles = titles

# If the figure tab group...
if w_idx == 0:
# Close figure and delete the data
plt.close(data['figs'][fig_idx])
del data['figs'][fig_idx]
data['figs'].pop(fig_idx)
# Redisplay the remaining children
n_tabs = len(tab.children)
for idx in range(n_tabs):
_fig_idx = _figname2idx(tab.get_title(idx))
Expand All @@ -522,10 +556,11 @@ def _close_figure(b, widgets, data, fig_idx):
with tab.children[idx]:
display(data['figs'][_fig_idx].canvas)

if n_tabs == 0:
widgets['figs_output'].clear_output()
with widgets['figs_output']:
display(Label(_fig_placeholder))
# If all children have been deleted display the placeholder
if n_tabs == 0:
widgets['figs_output'].clear_output()
with widgets['figs_output']:
display(Label(_fig_placeholder))


def _add_axes_controls(widgets, data, fig, axd):
Expand Down Expand Up @@ -565,8 +600,8 @@ def _add_figure(b, widgets, data, scale=0.95, dpi=96):
with widgets['figs_output']:
display(widgets['figs_tabs'])

widgets['figs_tabs'].children = widgets['figs_tabs'].children + (
fig_outputs, )
widgets['figs_tabs'].children = \
[s for s in widgets['figs_tabs'].children] + [fig_outputs]
widgets['figs_tabs'].set_title(n_tabs, _idx2figname(fig_idx))

with fig_outputs:
Expand Down Expand Up @@ -627,7 +662,7 @@ def __init__(self, gui_data, viz_layout):
self.figs_tabs = Tab()
self.axes_config_tabs.selected_index = None
self.figs_tabs.selected_index = None
link(
self.figs_config_tab_link = link(
(self.axes_config_tabs, 'selected_index'),
(self.figs_tabs, 'selected_index'),
)
Expand Down Expand Up @@ -711,6 +746,7 @@ def compose(self):
])
return config_panel, fig_output_container

@unlink_relink(attribute='figs_config_tab_link')
def add_figure(self, b=None):
"""Add a figure and corresponding config tabs to the dashboard.
"""
Expand All @@ -729,7 +765,7 @@ def _simulate_switch_fig_template(self, template_name):

def _simulate_delete_figure(self, fig_name):
tab = self.axes_config_tabs
titles = [tab.get_title(idx) for idx in range(len(tab.children))]
titles = tab.titles
assert fig_name in titles
tab_idx = titles.index(fig_name)

Expand Down Expand Up @@ -764,16 +800,13 @@ def _simulate_edit_figure(self, fig_name, ax_name, simulation_name,
assert operation in ("plot", "clear")

tab = self.axes_config_tabs
titles = [tab.get_title(idx) for idx in range(len(tab.children))]
titles = tab.titles
assert fig_name in titles, "No such figure"
tab_idx = titles.index(fig_name)
self.axes_config_tabs.selected_index = tab_idx

ax_control_tabs = self.axes_config_tabs.children[tab_idx].children[1]
ax_titles = [
ax_control_tabs.get_title(idx)
for idx in range(len(ax_control_tabs.children))
]
ax_titles = ax_control_tabs.titles
assert ax_name in ax_titles, "No such axis"
ax_idx = ax_titles.index(ax_name)
ax_control_tabs.selected_index = ax_idx
Expand Down
54 changes: 24 additions & 30 deletions hnn_core/gui/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import urllib.request
from collections import defaultdict
from pathlib import Path
from datetime import datetime
from IPython.display import IFrame, display
from ipywidgets import (HTML, Accordion, AppLayout, BoundedFloatText,
BoundedIntText, Button, Dropdown, FileUpload, VBox,
Expand Down Expand Up @@ -438,8 +439,7 @@ def compose(self, return_layout=True):
'Layer 2/3 Pyramidal', 'Layer 5 Pyramidal', 'Layer 2 Basket',
'Layer 5 Basket')
cell_connectivity = Accordion(children=connectivity_boxes)
for idx, connectivity_name in enumerate(connectivity_names):
cell_connectivity.set_title(idx, connectivity_name)
cell_connectivity.titles = [s for s in connectivity_names]

drive_selections = VBox([
self.add_drive_button, self.widget_drive_type_selection,
Expand Down Expand Up @@ -616,15 +616,14 @@ def _simulate_upload_drives(self, file_url):
self.load_drives_button.set_trait('value', uploaded_value)

def _simulate_left_tab_click(self, tab_title):
tab_index = None
# Get left tab group object
left_tab = self.app_layout.left_sidebar.children[0].children[0]
for idx in left_tab._titles.keys():
if tab_title == left_tab._titles[idx]:
tab_index = int(idx)
break
if tab_index is None:
raise ValueError("Incorrect tab title")
left_tab.selected_index = tab_index
# Check that the title is in the tab group
if tab_title in left_tab.titles:
# Simulate the user clicking on the tab
left_tab.selected_index = left_tab.titles.index(tab_title)
else:
raise ValueError("Tab title does not exist.")

def _simulate_make_figure(self,):
self._simulate_left_tab_click("Visualization")
Expand Down Expand Up @@ -655,16 +654,13 @@ def _prepare_upload_file_from_url(file_url):
for line in data:
content += line

return {
params_name: {
'metadata': {
'name': params_name,
'type': 'application/json',
'size': len(content),
},
'content': content
}
}
return [{
'name': params_name,
'type': 'application/json',
'size': len(content),
'content': content,
'last_modified': datetime.now()
}]


def create_expanded_button(description, button_style, layout, disabled=False,
Expand Down Expand Up @@ -1133,14 +1129,14 @@ def on_upload_data_change(change, data, viz_manager, log_out):
logger.info("Empty change")
return

key = list(change['new'].keys())[0]
data_dict = change['new'][0]

data_fname = change['new'][key]['metadata']['name'].rstrip('.txt')
data_fname = data_dict['name'].rstrip('.txt')
if data_fname in data['simulation_data'].keys():
logger.error(f"Found existing data: {data_fname}.")
return

ext_content = change['new'][key]['content']
ext_content = data_dict['content']
ext_content = codecs.decode(ext_content, encoding="utf-8")
with log_out:
data['simulation_data'][data_fname] = {'net': None, 'dpls': [
Expand All @@ -1163,10 +1159,9 @@ def on_upload_params_change(change, params, tstop, dt, log_out, drive_boxes,
logger.info("Empty change")
return
logger.info("Loading connectivity...")
key = list(change['new'].keys())[0]

params_fname = change['new'][key]['metadata']['name']
param_data = change['new'][key]['content']
param_dict = change['new'][0]
params_fname = param_dict['name']
param_data = param_dict['content']

param_data = codecs.decode(param_data, encoding="utf-8")

Expand All @@ -1191,9 +1186,8 @@ def on_upload_params_change(change, params, tstop, dt, log_out, drive_boxes,
layout)
else:
raise ValueError

change['owner'].set_trait('_counter', 0)
change['owner'].set_trait('value', {})
# Resets file counter to 0
change['owner'].set_trait('value', ([]))


def _init_network_from_widgets(params, dt, tstop, single_simulation_data,
Expand Down
53 changes: 52 additions & 1 deletion hnn_core/tests/test_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@
import matplotlib.pyplot as plt
import numpy as np
import pytest
import traitlets

from hnn_core import Dipole, Network, Params
from hnn_core.gui import HNNGUI
from hnn_core.gui._viz_manager import _idx2figname, _no_overlay_plot_types
from hnn_core.gui._viz_manager import _idx2figname, _no_overlay_plot_types, \
unlink_relink
from hnn_core.gui.gui import _init_network_from_widgets
from hnn_core.network import pick_connection
from hnn_core.network_models import jones_2009_model
from hnn_core.parallel_backends import requires_mpi4py, requires_psutil
from IPython.display import IFrame
from ipywidgets import Tab, Text, link

matplotlib.use('agg')

Expand Down Expand Up @@ -416,3 +420,50 @@ def test_gui_adaptive_spectrogram():
for attr in dir(gui.viz_manager.figs[figid])]) is False
assert len(gui.viz_manager.figs[1].axes) == 2
plt.close('all')


def test_unlink_relink_widget():
"""Tests the unlinking and relinking of widgets decorator."""

# Create a basic version of the VizManager class
class MiniViz:
def __init__(self):
self.tab_group_1 = Tab()
self.tab_group_2 = Tab()
self.tab_link = link(
(self.tab_group_1, 'selected_index'),
(self.tab_group_2, 'selected_index'),
)

def add_child(self, to_add=1):
n_tabs = len(self.tab_group_2.children) + to_add
# Add figure tab and select latest tab
self.tab_group_1.children = \
[Text(f'Test{s}') for s in np.arange(n_tabs)]
self.tab_group_1.selected_index = n_tabs - 1

self.tab_group_2.children = \
[Text(f'Test{s}') for s in np.arange(n_tabs)]
self.tab_group_2.selected_index = n_tabs - 1

@unlink_relink(attribute='tab_link')
def add_child_decorated(self, to_add):
self.add_child(to_add)

# Check that widgets are linked.
# Error from tab groups momentarily having a different number of children
gui = MiniViz()
with pytest.raises(traitlets.TraitError, match='.*index out of bounds.*'):
gui.add_child(2)

# Check decorator unlinks and is able to make a change
gui = MiniViz()
gui.add_child_decorated(2)
assert len(gui.tab_group_1.children) == 2
assert gui.tab_group_1.selected_index == 1
assert len(gui.tab_group_2.children) == 2
assert gui.tab_group_2.selected_index == 1

# Check if the widgets are relinked, the selected index should be synced
gui.tab_group_1.selected_index = 0
assert gui.tab_group_2.selected_index == 0
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def run(self):
'h5io'
],
extras_require={
'gui': ['ipywidgets <=7.7.1', 'ipympl<0.9', 'voila<=0.3.6'],
'gui': ['ipywidgets>=8.0.0', 'ipykernel', 'ipympl', 'voila'],
'opt': ['scikit-learn']
},
python_requires='>=3.8',
Expand Down

0 comments on commit 2d4bd76

Please sign in to comment.