Skip to content

Commit

Permalink
[MRG] New widget to adjust default smoothing value in gui (#924)
Browse files Browse the repository at this point in the history
* new widget to adjust default smoothing value in gui

* set default smoothing to 0 for data

* update default smoothing in viz manager class when clicking run button

* updating code to allow more general specification of default visualization params

* update smoothing test

* fix comment typo

* Update hnn_core/gui/_viz_manager.py

Remove commented out code

Co-authored-by: Nicholas Tolley <[email protected]>

* remove commented out code; whoops!

* update test to check that the unadjusted default smoothing is the same everywhere per the suggestion from @ntolley

* rename vars to remove need for property, update comments

---------

Co-authored-by: Nicholas Tolley <[email protected]>
  • Loading branch information
dylansdaniels and ntolley authored Nov 22, 2024
1 parent 11e90c4 commit 8c1351e
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 11 deletions.
19 changes: 12 additions & 7 deletions hnn_core/gui/_viz_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,11 +549,12 @@ def _clear_axis(b, widgets, data, fig_idx, fig, ax, widgets_plot_type,
_dynamic_rerender(fig)


def _get_ax_control(widgets, data, fig_idx, fig, ax):
def _get_ax_control(widgets, data, fig_default_params, fig_idx, fig, ax):
analysis_style = {'description_width': '200px'}
layout = Layout(width="98%")
simulation_names = tuple(data['simulations'].keys())
sim_index = 0
default_smoothing = fig_default_params['default_smoothing']
if not simulation_names:
simulation_names = ("None",)
else:
Expand Down Expand Up @@ -610,7 +611,7 @@ def _get_ax_control(widgets, data, fig_idx, fig, ax):
style=analysis_style,
)
simulation_dipole_smooth = FloatText(
value=30,
value=default_smoothing,
description='Dipole Smooth Window (ms):',
disabled=False,
layout=layout,
Expand Down Expand Up @@ -761,12 +762,13 @@ def _close_figure(b, widgets, data, fig_idx):
display(Label(_fig_placeholder))


def _add_axes_controls(widgets, data, fig, axd):
def _add_axes_controls(widgets, data, fig_default_smoothing, fig, axd):
fig_idx = data['fig_idx']['idx']

controls = Tab()
children = [
_get_ax_control(widgets, data, fig_idx=fig_idx, fig=fig, ax=ax)
_get_ax_control(widgets, data, fig_default_smoothing, fig_idx=fig_idx,
fig=fig, ax=ax)
for ax_key, ax in axd.items()
]
controls.children = children
Expand All @@ -786,7 +788,8 @@ def _add_axes_controls(widgets, data, fig, axd):
widgets['axes_config_tabs'].set_title(n_tabs, _idx2figname(fig_idx))


def _add_figure(b, widgets, data, template_type, scale=0.95, dpi=96):
def _add_figure(b, widgets, data, fig_default_smoothing,
template_type, scale=0.95, dpi=96):
fig_idx = data['fig_idx']['idx']
viz_output_layout = data['visualization_output']
fig_outputs = Output()
Expand Down Expand Up @@ -818,7 +821,7 @@ def _add_figure(b, widgets, data, template_type, scale=0.95, dpi=96):
else:
display(fig.canvas)

_add_axes_controls(widgets, data, fig=fig, axd=axd)
_add_axes_controls(widgets, data, fig_default_smoothing, fig=fig, axd=axd)

data['figs'][fig_idx] = fig
widgets['figs_tabs'].selected_index = n_tabs
Expand Down Expand Up @@ -869,9 +872,10 @@ class _VizManager:
A dict of external simulation data object
"""

def __init__(self, gui_data, viz_layout):
def __init__(self, gui_data, viz_layout, fig_default_params):
plt.close("all")
self.viz_layout = viz_layout
self.fig_default_params = fig_default_params
self.use_ipympl = 'ipympl' in matplotlib.get_backend()

self.axes_config_output = Output()
Expand Down Expand Up @@ -1029,6 +1033,7 @@ def add_figure(self, b=None):
_add_figure(None,
self.widgets,
self.data,
self.fig_default_params,
template_type,
scale=0.97,
dpi=self.viz_layout['dpi'])
Expand Down
29 changes: 25 additions & 4 deletions hnn_core/gui/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,15 @@ def __init__(self, theme_color="#802989",
# In-memory storage of all simulation and visualization related data
self.simulation_data = defaultdict(lambda: dict(net=None, dpls=list()))

# Default visualization params for figures
self.widget_default_smoothing = BoundedFloatText(
value=30.0, description='Smoothing:',
min=0.0, max=100.0, step=1.0, disabled=False)

self.fig_default_params = {
'default_smoothing': self.widget_default_smoothing.value
}

# Simulation parameters
self.widget_tstop = BoundedFloatText(
value=170, description='tstop (ms):', min=0, max=1e6, step=1,
Expand Down Expand Up @@ -476,7 +485,8 @@ def _init_ui_components(self):

self._log_out = Output()

self.viz_manager = _VizManager(self.data, self.layout)
self.viz_manager = _VizManager(self.data, self.layout,
self.fig_default_params)

# detailed configuration of backends
self._backend_config_out = Output()
Expand Down Expand Up @@ -565,6 +575,7 @@ def _run_button_clicked(b):
return run_button_clicked(
self.widget_simulation_name, self._log_out, self.drive_widgets,
self.data, self.widget_dt, self.widget_tstop,
self.fig_default_params, self.widget_default_smoothing,
self.widget_ntrials, self.widget_backend_selection,
self.widget_mpi_cmd, self.widget_n_jobs, self.params,
self._simulation_status_bar, self._simulation_status_contents,
Expand Down Expand Up @@ -669,8 +680,8 @@ def compose(self, return_layout=True):
simulation_box = VBox([
VBox([
self.widget_simulation_name, self.widget_tstop, self.widget_dt,
self.widget_ntrials, self.widget_backend_selection,
self._backend_config_out]),
self.widget_ntrials, self.widget_default_smoothing,
self.widget_backend_selection, self._backend_config_out]),
], layout=self.layout['config_box'])

connectivity_configuration = Tab()
Expand Down Expand Up @@ -1910,7 +1921,9 @@ def _init_network_from_widgets(params, dt, tstop, single_simulation_data,


def run_button_clicked(widget_simulation_name, log_out, drive_widgets,
all_data, dt, tstop, ntrials, backend_selection,
all_data, dt, tstop,
fig_default_params, widget_default_smoothing,
ntrials, backend_selection,
mpi_cmd, n_jobs, params, simulation_status_bar,
simulation_status_contents, connectivity_textfields,
viz_manager, simulations_list_widget,
Expand Down Expand Up @@ -1960,6 +1973,14 @@ def run_button_clicked(widget_simulation_name, log_out, drive_widgets,
simulations_list_widget.value = sim_names[0]

viz_manager.reset_fig_config_tabs()

# update default_smoothing in gui based on widget
fig_default_params['default_smoothing'] = widget_default_smoothing.value

# change default smoothing in viz_manager to mirror gui
new_default_smoothing = fig_default_params['default_smoothing']
viz_manager.fig_default_params['default_smoothing'] = new_default_smoothing

viz_manager.add_figure()
fig_name = _idx2figname(viz_manager.data['fig_idx']['idx'] - 1)
ax_plots = [("ax0", "input histogram"), ("ax1", "current dipole")]
Expand Down
58 changes: 58 additions & 0 deletions hnn_core/tests/test_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,3 +1116,61 @@ def test_delete_single_drive(setup_gui):
'alpha_prox (proximal)',
'poisson (proximal)',
'tonic')


def test_default_smoothing(setup_gui):
"""Tests default smoothing is inherited correctly"""

gui = setup_gui
gui.run_button.click()

# check that the unadjusted default smoothing is the same everywhere
gui_smooth_value = gui.fig_default_params['default_smoothing']
viz_smooth_value = gui.viz_manager.fig_default_params['default_smoothing']

assert gui_smooth_value == 30
assert viz_smooth_value == 30

# update simulation name
gui.widget_simulation_name.value = 'no_smoothing'

# change value of default smoothing in the widget
new_smoothing = 0
gui.widget_default_smoothing.value = new_smoothing

gui.run_button.click()

# check that the new default smoothing value is set everywhere
gui_smooth_value = gui.fig_default_params['default_smoothing']
viz_smooth_value = gui.viz_manager.fig_default_params['default_smoothing']

assert gui_smooth_value == new_smoothing
assert viz_smooth_value == new_smoothing

# check that dipole plots have data
gui._simulate_viz_action("switch_fig_template", "[Blank] single figure")
gui._simulate_viz_action("add_fig")
figid = 2
figname = f'Figure {figid}'
axname = 'ax0'

_dipole_plot_types = [
'current dipole',
'layer2 dipole',
'layer5 dipole',
]

for viz_type in _dipole_plot_types:
gui._simulate_viz_action(
"edit_figure", figname,
axname, 'no_smoothing', viz_type, {}, 'clear'
)

gui._simulate_viz_action(
"edit_figure", figname,
axname, 'no_smoothing', viz_type, {}, 'plot')

# Check if data is plotted on the axes
assert gui.viz_manager.figs[figid].axes[0].has_data()

plt.close('all')

0 comments on commit 8c1351e

Please sign in to comment.