Skip to content

Commit

Permalink
Add GroupedSPMe model (#584)
Browse files Browse the repository at this point in the history
* Add first go at grouped SPMe

* Update thicknesses

* Update overpotentials

* Add double-layer

* Update example

* Remove factors of length

* Fix domain parameters

* Update series resistance and etas

* Set constant conductivity

* Remove factor of 3

* Add parameter conversion function

* Combine plots

* Update model options

* Rename variables

* Add conductivities and init_soc

* Update averaging

* Move timescales

* Rescale electrolyte source

* Assume high conductivities

* Add electrolyte flux

* Add transfer coefficient

* Update grouped_SPMe

* Revert to sto_e, make double layer optional

* Add target continuity conditions

* Simplify ocv_init setting

* Remove concatenations

* Switch from theoretical to measured capacity

* Move example into subfolder

* Fix electrolyte scaling

* Add potentials to quick plot

* Update option setting

* Grouped SPMe edit (#577)

* Add Battery voltage to the variables list

* Create grouped_SPMe_experiment.py

* Move set_initial_state to base models

* Remove testing script

* Add tests on GroupedSPMe

* Test differential surface form

* Fix option setting

* Update README.md

* Add test_grouped_SPMe

* Increase coverage

* Update name and option setting

* Remove comments

* Combine if statements

* Update CHANGELOG.md

---------

Co-authored-by: Noël Hallemans <[email protected]>
  • Loading branch information
NicolaCourtier and noelhallemans authored Dec 11, 2024
1 parent 0b6ae29 commit 23aa9dd
Show file tree
Hide file tree
Showing 12 changed files with 938 additions and 36 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

## Features

- [#571](https://github.com/pybop-team/PyBOP/pull/571) - Adds Multistart functionality to optimisers via initialisation arg `multistart`
- [#584](https://github.com/pybop-team/PyBOP/pull/584) - Adds the `GroupedSPMe` model for parameter identification.
- [#571](https://github.com/pybop-team/PyBOP/pull/571) - Adds Multistart functionality to optimisers via initialisation arg `multistart`.
- [#582](https://github.com/pybop-team/PyBOP/pull/582) - Fixes `population_size` arg for Pints' based optimisers, reshapes `parameters.rvs` to be parameter instances.
- [#570](https://github.com/pybop-team/PyBOP/pull/570) - Updates the contour and surface plots, adds mixed chain effective sample size computation, x0 to optim.log
- [#566](https://github.com/pybop-team/PyBOP/pull/566) - Adds `UnitHyperCube` transformation class, fixes incorrect application of gradient transformation.
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,10 @@ The table below lists the currently supported [models](https://github.com/pybop-
| Single Particle Model with Electrolyte (SPMe) | Particle Swarm Optimization (PSO) | Root Mean Squared Error (RMSE) <tr></tr> |
| Doyle-Fuller-Newman (DFN) | Exponential Natural Evolution Strategy (xNES) | Minkowski <tr></tr> |
| Many Particle Model (MPM) | Separable Natural Evolution Strategy (sNES) | Sum of Power <tr></tr> |
| Multi-Species Multi-Reactants (MSMR) | Adaptive Moment Estimation with Weight Decay (AdamW) | Gaussian Log Likelihood <tr></tr> |
| Multi-Species Multi-Reaction (MSMR) | Adaptive Moment Estimation with Weight Decay (AdamW) | Gaussian Log Likelihood <tr></tr> |
| Weppner-Huggins | Improved Resilient Backpropagation (iRProp-) | Log Posterior <tr></tr> |
| Equivalent Circuit Models (ECM) | SciPy Minimize & Differential Evolution | Unscented Kalman Filter (UKF) <tr></tr> |
| | Cuckoo Search | Gravimetric Energy Density <tr></tr> |
| Grouped-parameter SPMe (GroupedSPMe) | Cuckoo Search | Gravimetric Energy Density <tr></tr> |
| | Gradient Descent | Volumetric Energy Density<tr></tr> |
| | Nelder-Mead | <tr></tr> |

Expand Down
113 changes: 113 additions & 0 deletions examples/scripts/comparison_examples/grouped_SPMe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import matplotlib.pyplot as plt
import numpy as np

import pybop
from pybop.models.lithium_ion.basic_SPMe import convert_physical_to_grouped_parameters

# Prepare figure
layout_options = dict(
xaxis_title="Time / s",
yaxis_title="Voltage / V",
)
plot_dict = pybop.plot.StandardPlot(layout_options=layout_options)

# Unpack parameter values from Chen2020
parameter_set = pybop.ParameterSet.pybamm("Chen2020")

# Fix the electrolyte diffusivity and conductivity
ce0 = parameter_set["Initial concentration in electrolyte [mol.m-3]"]
T = parameter_set["Ambient temperature [K]"]
parameter_set["Electrolyte diffusivity [m2.s-1]"] = parameter_set[
"Electrolyte diffusivity [m2.s-1]"
](ce0, T)
parameter_set["Electrolyte conductivity [S.m-1]"] = parameter_set[
"Electrolyte conductivity [S.m-1]"
](ce0, T)

# Define a test protocol
initial_state = {"Initial SoC": 0.9}
experiment = pybop.Experiment(
[
"Discharge at 1C until 2.5 V (5 seconds period)",
"Rest for 30 minutes (5 seconds period)",
# "Charge at 2C until 4.1 V (5 seconds period)",
# "Rest for 30 minutes (5 seconds period)",
],
)

# Run an example SPMe simulation
model_options = {"surface form": "differential", "contact resistance": "true"}
time_domain_SPMe = pybop.lithium_ion.SPMe(
parameter_set=parameter_set,
options=model_options,
)
simulation = time_domain_SPMe.predict(
initial_state=initial_state, experiment=experiment
)
dataset = pybop.Dataset(
{
"Time [s]": simulation["Time [s]"].data,
"Current function [A]": simulation["Current [A]"].data,
"Voltage [V]": simulation["Voltage [V]"].data,
}
)
plot_dict.add_traces(dataset["Time [s]"], dataset["Voltage [V]"])

# Test model in the time domain
grouped_parameter_set = convert_physical_to_grouped_parameters(parameter_set)
time_domain_grouped = pybop.lithium_ion.GroupedSPMe(
parameter_set=grouped_parameter_set,
options=model_options,
build=True,
)
time_domain_grouped.set_initial_state(initial_state)
time_domain_grouped.set_current_function(dataset)
simulation = time_domain_grouped.predict(t_eval=dataset["Time [s]"])
dataset = pybop.Dataset(
{
"Time [s]": simulation["Time [s]"].data,
"Current function [A]": simulation["Current [A]"].data,
"Voltage [V]": simulation["Voltage [V]"].data,
}
)
plot_dict.add_traces(dataset["Time [s]"], dataset["Voltage [V]"], line_dash="dash")
plot_dict()

# Set up figure
fig, ax = plt.subplots()
ax.grid()

# Compare models in the frequency domain
freq_domain_SPMe = pybop.lithium_ion.SPMe(
parameter_set=parameter_set, options=model_options, eis=True
)
freq_domain_grouped = pybop.lithium_ion.GroupedSPMe(
parameter_set=grouped_parameter_set,
options=model_options,
eis=True,
build=True,
)

for i, model in enumerate([freq_domain_SPMe, freq_domain_grouped]):
NSOC = 11
Nfreq = 60
fmin = 4e-4
fmax = 1e3
SOCs = np.linspace(0, 1, NSOC)
frequencies = np.logspace(np.log10(fmin), np.log10(fmax), Nfreq)

impedances = 1j * np.zeros((Nfreq, NSOC))
for ii, SOC in enumerate(SOCs):
model.set_initial_state({"Initial SoC": SOC})
simulation = model.simulateEIS(inputs=None, f_eval=frequencies)
impedances[:, ii] = simulation["Impedance"]

if i == 0:
ax.plot(np.real(impedances[:, ii]), -np.imag(impedances[:, ii]), "b")
if i == 1:
ax.plot(np.real(impedances[:, ii]), -np.imag(impedances[:, ii]), "r--")

# Show figure
ax.set(xlabel=r"$Z_r(\omega)$ [$\Omega$]", ylabel=r"$-Z_j(\omega)$ [$\Omega$]")
ax.set_aspect("equal", "box")
plt.show()
35 changes: 6 additions & 29 deletions pybop/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,30 +241,7 @@ def set_initial_state(self, initial_state: dict, inputs: Optional[Inputs] = None
"""
self.clear()

initial_state = self.convert_to_pybamm_initial_state(initial_state)

if isinstance(self.pybamm_model, pybamm.equivalent_circuit.Thevenin):
initial_state = self.get_initial_state(initial_state, inputs=inputs)
self._unprocessed_parameter_set.update({"Initial SoC": initial_state})

else:
if not self.pybamm_model._built: # noqa: SLF001
self.pybamm_model.build_model()

# Temporary construction of attributes for PyBaMM
self._model = self.pybamm_model
self._unprocessed_parameter_values = self._unprocessed_parameter_set

# Set initial state via PyBaMM's Simulation class
pybamm.Simulation.set_initial_soc(self, initial_state, inputs=inputs)

# Update the default parameter set for consistency
self._unprocessed_parameter_set = self._parameter_values

# Clear the pybamm objects
del self._model
del self._unprocessed_parameter_values
del self._parameter_values
self._set_initial_state(initial_state=initial_state, inputs=inputs)

# Use a copy of the updated default parameter set
self._parameter_set = self._unprocessed_parameter_set.copy()
Expand Down Expand Up @@ -331,11 +308,11 @@ def set_up_for_eis(self, model):
model.param, None, model.options, control="algebraic"
).get_fundamental_variables()

# Perform the replacement
symbol_replacement_map = {
model.variables[name]: variable
for name, variable in external_circuit_variables.items()
}
# Define the variables to replace
symbol_replacement_map = {}
for name, variable in external_circuit_variables.items():
if name in model.variables.keys():
symbol_replacement_map[model.variables[name]] = variable

# Don't replace initial conditions, as these should not contain
# Variable objects
Expand Down
18 changes: 18 additions & 0 deletions pybop/models/empirical/base_ecm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

import numpy as np
import pybamm

Expand Down Expand Up @@ -121,6 +123,22 @@ def _check_params(
return self.param_checker(inputs, allow_infeasible_solutions)
return True

def _set_initial_state(self, initial_state: dict, inputs: Optional[Inputs] = None):
"""
Set the initial state of charge or concentrations for the battery model.
Parameters
----------
initial_state : dict
A valid initial state, e.g. the initial state of charge or open-circuit voltage.
inputs : Inputs
The input parameters to be used when building the model.
"""
initial_state = self.convert_to_pybamm_initial_state(initial_state)

initial_state = self.get_initial_state(initial_state, inputs=inputs)
self._unprocessed_parameter_set.update({"Initial SoC": initial_state})

def get_initial_state(
self,
initial_value,
Expand Down
2 changes: 1 addition & 1 deletion pybop/models/lithium_ion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
# Import lithium ion based models
#
from .base_echem import EChemBaseModel
from .echem import SPM, SPMe, DFN, MPM, MSMR, WeppnerHuggins
from .echem import SPM, SPMe, DFN, MPM, MSMR, WeppnerHuggins, GroupedSPMe
33 changes: 32 additions & 1 deletion pybop/models/lithium_ion/base_echem.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import warnings
from typing import Optional

from pybamm import LithiumIonParameters
from pybamm import LithiumIonParameters, Simulation
from pybamm import lithium_ion as pybamm_lithium_ion

from pybop.models.base_model import BaseModel, Inputs
Expand Down Expand Up @@ -154,6 +154,37 @@ def _check_params(

return True

def _set_initial_state(self, initial_state: dict, inputs: Optional[Inputs] = None):
"""
Set the initial state of charge or concentrations for the battery model.
Parameters
----------
initial_state : dict
A valid initial state, e.g. the initial state of charge or open-circuit voltage.
inputs : Inputs
The input parameters to be used when building the model.
"""
initial_state = self.convert_to_pybamm_initial_state(initial_state)

if not self.pybamm_model._built: # noqa: SLF001
self.pybamm_model.build_model()

# Temporary construction of attributes for PyBaMM
self._model = self.pybamm_model
self._unprocessed_parameter_values = self._unprocessed_parameter_set

# Set initial state via PyBaMM's Simulation class
Simulation.set_initial_soc(self, initial_state, inputs=inputs)

# Update the default parameter set for consistency
self._unprocessed_parameter_set = self._parameter_values

# Clear the pybamm objects
del self._model
del self._unprocessed_parameter_values
del self._parameter_values

def cell_volume(self, parameter_set: Optional[ParameterSet] = None):
"""
Calculate the total cell volume in m3.
Expand Down
Loading

0 comments on commit 23aa9dd

Please sign in to comment.