Skip to content

Commit

Permalink
Merge pull request #160 from pybamm-team/dev
Browse files Browse the repository at this point in the history
Remove Dask
  • Loading branch information
TomTranter authored Jul 1, 2022
2 parents adcd594 + c5e86f2 commit 16b7fda
Show file tree
Hide file tree
Showing 10 changed files with 3 additions and 178 deletions.
3 changes: 0 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,6 @@ ENV/
env.bak/
venv.bak/

# Dask files
dask-worker-space/

# Generated documentation directory
site/
examples/csv-results/
Expand Down
1 change: 0 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ dependencies:
- click
- autograd
- absl-py
- dask
- hiredis
- pip:
- pybamm
Expand Down
46 changes: 0 additions & 46 deletions examples/dask_solver.py

This file was deleted.

1 change: 0 additions & 1 deletion liionpack/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
from .definitions import CIRCUIT_DIR
from .solvers import CasadiManager
from .solvers import RayManager
from .solvers import DaskManager
from .solvers import GenericActor
from .solvers import RayActor

Expand Down
4 changes: 1 addition & 3 deletions liionpack/solver_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def solve(
output_variables (list):
Variables to evaluate during solve. Must be a valid key in the
model.variables
manager (string, can be - ["casadi", "ray", "dask"]):
manager (string, can be - ["casadi", "ray"]):
The solver manager to use for solving the electrochemical problem.
Returns:
Expand All @@ -419,8 +419,6 @@ def solve(
rm = lp.CasadiManager()
elif manager == "ray":
rm = lp.RayManager()
elif manager == "dask":
rm = lp.DaskManager()
else:
rm = lp.CasadiManager()
lp.logger.notice("manager instruction not supported, using default")
Expand Down
109 changes: 0 additions & 109 deletions liionpack/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import ray
import numpy as np
import time as ticker
from dask.distributed import Client
from tqdm import tqdm
import pybamm

Expand Down Expand Up @@ -683,111 +682,3 @@ def log_event(self):

def cleanup(self):
pass


class DaskManager(GenericManager):
def __init__(self, **kwargs):
super().__init__(**kwargs)

def split_models(self, Nspm, nproc):
# Manage the number of SPM models per worker
self.split_index = np.array_split(np.arange(Nspm), nproc)
self.spm_per_worker = [len(s) for s in self.split_index]
self.slices = []
for i in range(nproc):
self.slices.append(
slice(self.split_index[i][0], self.split_index[i][-1] + 1)
)

def setup_actors(self, nproc, inputs, initial_soc):
# Set up a casadi actor on each process
lp.logger.notice("Dask initialization started")
self.client = Client(n_workers=nproc)
lp.logger.notice("Dask initialization complete")
tic = ticker.time()
futures = []
for i in range(nproc):
# Create actor on each worker containing a simulation
futures.append(self.client.submit(GenericActor, actor=True, pure=False))
self.actors = [af.result() for af in futures]
futures = []
for i, a in enumerate(self.actors):
futures.append(
a.setup(
Nspm=self.spm_per_worker[i],
sim_func=self.sim_func,
parameter_values=self.parameter_values,
inputs=inputs[self.slices[i]],
dt=self.dt,
variable_names=self.variable_names,
initial_soc=initial_soc,
nproc=1,
)
)

_ = [af.result() for af in futures]
toc = ticker.time()
lp.logger.info(
"Dask actors setup in time " + str(np.around(toc - tic, 3)) + "s"
)

def step_actors(self):
tic = ticker.time()
inputs = self.build_inputs()
future_steps = []
for i, a in enumerate(self.actors):
future_steps.append(a.step(inputs=inputs[i]))
events = [af.result() for af in future_steps]
if np.any(events):
self.log_event()
toc = ticker.time()
lp.logger.info(
"Dask actors stepped in time " + str(np.around(toc - tic, 3)) + "s"
)

def evaluate_actors(self):
tic = ticker.time()
inputs = self.build_inputs()
future_evals = []
for i, a in enumerate(self.actors):
future_evals.append(a.evaluate(inputs=inputs[i]))
_ = [af.result() for af in future_evals]
toc = ticker.time()
lp.logger.info(
"Dask actors evaluated in time " + str(np.around(toc - tic, 3)) + "s"
)

def get_actor_output(self, step):
tic = ticker.time()
future_gets = []
for i, a in enumerate(self.actors):
future_gets.append(a.output())
for i, fg in enumerate(future_gets):
out = fg.result()
self.output[:, step, self.split_index[i]] = out
toc = ticker.time()
lp.logger.info(
"Dask,actors output got in time " + str(np.around(toc - tic, 3)) + "s"
)

def log_event(self):
futures = []
for actor in self.actors:
futures.append(actor.get_event_change())
all_event_changes = []
for i, f in enumerate(futures):
all_event_changes.append(np.asarray(f.result()))
event_change = np.hstack(all_event_changes)
Nr, Nc = event_change.shape
event_names = self.actors[0].get_event_names().result()
for r in range(Nr):
if np.any(event_change[r, :]):
lp.logger.warning(
event_names[r]
+ ", Batteries: "
+ str(np.where(event_change[r, :])[0].tolist())
)

def cleanup(self):
lp.logger.notice("Shutting down Dask client")
self.client.shutdown()
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ Ipython
scikit-spatial
networkx
textwrapper
dask[complete]
ray
redis
protobuf >= 3.8.0, < 4.0.0
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
"scikit-spatial",
"networkx",
"textwrapper",
"dask[complete]",
"ray",
"redis",
],
Expand Down
13 changes: 1 addition & 12 deletions tests/integration/test_all_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,8 @@ def test_consistent_results_1_step(self):
nproc=1,
manager="casadi"
)
# Solve pack with dask
b = lp.solve(
netlist=netlist,
parameter_values=parameter_values,
experiment=experiment,
inputs=None,
nproc=1,
manager="dask"
)
# Solve pack with ray
c = lp.solve(
b = lp.solve(
netlist=netlist,
parameter_values=parameter_values,
experiment=experiment,
Expand All @@ -50,10 +41,8 @@ def test_consistent_results_1_step(self):

v_a = a["Terminal voltage [V]"]
v_b = b["Terminal voltage [V]"]
v_c = c["Terminal voltage [V]"]

assert np.allclose(v_a, v_b)
assert np.allclose(v_b, v_c)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def setUpClass(self):
)
# PyBaMM parameters
self.parameter_values = pybamm.ParameterValues("Chen2020")
self.managers = ["casadi", "ray", "dask"]
self.managers = ["casadi", "ray"]

def test_multiprocessing(self):
for manager in self.managers:
Expand Down

0 comments on commit 16b7fda

Please sign in to comment.