-
Notifications
You must be signed in to change notification settings - Fork 53
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ENH: Refactor
BatchSimulate
Example and Improve Documentation (#857)
* ENH BatchSimulate for JSON path handling Signed-off-by: samadpls <[email protected]> * Fix: Docstring of dic_to_network Signed-off-by: samadpls <[email protected]> * ENH: Reorder the init params Signed-off-by: samadpls <[email protected]> * Added record_isec vsec validation test Signed-off-by: samadpls <[email protected]> * Removed net_json param and update test Signed-off-by: samadpls <[email protected]> * Refactored Example and initialize parameters Co-authored-by: Nicholas Tolley <[email protected]> Signed-off-by: samadpls <[email protected]> * Enh: visualization of dipole responses in plot_batch_simulate Co-authored-by: Nicholas Tolley <[email protected]> Signed-off-by: samadpls <[email protected]> * Refactor batch simulation parameters and backend Signed-off-by: samadpls <[email protected]> * [MRG] Fix indexing for batch simulations (#5) * ENH BatchSimulate for JSON path handling Signed-off-by: samadpls <[email protected]> * Fix: Docstring of dic_to_network Signed-off-by: samadpls <[email protected]> * ENH: Reorder the init params Signed-off-by: samadpls <[email protected]> * Added record_isec vsec validation test Signed-off-by: samadpls <[email protected]> * Removed net_json param and update test Signed-off-by: samadpls <[email protected]> * Refactored Example and initialize parameters Co-authored-by: Nicholas Tolley <[email protected]> Signed-off-by: samadpls <[email protected]> * Enh: visualization of dipole responses in plot_batch_simulate Co-authored-by: Nicholas Tolley <[email protected]> Signed-off-by: samadpls <[email protected]> * Refactor batch simulation parameters and backend Signed-off-by: samadpls <[email protected]> * batches run in parallel --------- Signed-off-by: samadpls <[email protected]> Co-authored-by: samadpls <[email protected]> * Refactor: Removed joblib from simulate_dipole, and added parallel execution test. Signed-off-by: samadpls <[email protected]> * Refactor: Simplify BatchSimulate parameters by removing n_jobs Signed-off-by: samadpls <[email protected]> * Remove unused Dask code and simplify BatchSimulate initialization Signed-off-by: samadpls <[email protected]> --------- Signed-off-by: samadpls <[email protected]> Co-authored-by: Nicholas Tolley <[email protected]>
- Loading branch information
Showing
4 changed files
with
132 additions
and
66 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,13 +3,18 @@ | |
# Ryan Thorpe <[email protected]> | ||
# Mainak Jas <[email protected]> | ||
|
||
from pathlib import Path | ||
import time | ||
import pytest | ||
import numpy as np | ||
import os | ||
|
||
from hnn_core.batch_simulate import BatchSimulate | ||
from hnn_core import jones_2009_model | ||
|
||
hnn_core_root = Path(__file__).parents[1] | ||
assets_path = Path(hnn_core_root, 'tests', 'assets') | ||
|
||
|
||
@pytest.fixture | ||
def batch_simulate_instance(tmp_path): | ||
|
@@ -33,11 +38,12 @@ def set_params(param_values, net): | |
weights_ampa=weights_ampa, | ||
synaptic_delays=synaptic_delays) | ||
|
||
net = jones_2009_model() | ||
net = jones_2009_model(mesh_shape=(3, 3)) | ||
return BatchSimulate(net=net, set_params=set_params, | ||
tstop=1., | ||
tstop=10, | ||
save_folder=tmp_path, | ||
batch_size=3) | ||
batch_size=3, | ||
n_trials=3,) | ||
|
||
|
||
@pytest.fixture | ||
|
@@ -75,6 +81,12 @@ def test_parameter_validation(): | |
with pytest.raises(TypeError, match="net must be"): | ||
BatchSimulate(net="invalid_network", set_params=lambda x: x) | ||
|
||
with pytest.raises(ValueError, match="'record_vsec' parameter"): | ||
BatchSimulate(set_params=lambda x: x, record_vsec="invalid") | ||
|
||
with pytest.raises(ValueError, match="'record_isec' parameter"): | ||
BatchSimulate(set_params=lambda x: x, record_isec="invalid") | ||
|
||
|
||
def test_generate_param_combinations(batch_simulate_instance, param_grid): | ||
"""Test generating parameter combinations.""" | ||
|
@@ -280,3 +292,27 @@ def test_load_results(batch_simulate_instance, param_grid, tmp_path): | |
# Validation Tests | ||
with pytest.raises(TypeError, match='results must be'): | ||
batch_simulate_instance._save("invalid_results", start_idx, end_idx) | ||
|
||
|
||
def test_parallel_execution(batch_simulate_instance, param_grid): | ||
"""Test parallel execution of simulations and ensure speedup.""" | ||
|
||
param_combinations = batch_simulate_instance._generate_param_combinations( | ||
param_grid) | ||
|
||
start_time = time.perf_counter() | ||
_ = batch_simulate_instance.simulate_batch( | ||
param_combinations, n_jobs=1, backend='loky') | ||
end_time = time.perf_counter() | ||
serial_time = end_time - start_time | ||
|
||
start_time = time.perf_counter() | ||
_ = batch_simulate_instance.simulate_batch( | ||
param_combinations, | ||
n_jobs=2, | ||
backend='loky') | ||
end_time = time.perf_counter() | ||
parallel_time = end_time - start_time | ||
|
||
assert (serial_time > parallel_time | ||
), "Parallel execution is not faster than serial execution!" |