From 5fa94194de07374b6b76d3e8cb72093094b3df41 Mon Sep 17 00:00:00 2001 From: NicolaCourtier <45851982+NicolaCourtier@users.noreply.github.com> Date: Thu, 25 Jul 2024 11:21:30 +0100 Subject: [PATCH 1/3] Add get_parameter_info with test --- CHANGELOG.md | 1 + pybop/models/base_model.py | 15 +++++++++++++++ tests/unit/test_models.py | 5 +++++ 3 files changed, 21 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 51d43e025..02ed8460a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## Features +- [#418](https://github.com/pybop-team/PyBOP/issues/418) - Wraps the `get_parameter_info` method from PyBaMM to get a dictionary of parameter names and types. - [#327](https://github.com/pybop-team/PyBOP/issues/327) - Adds the `WeightedCost` subclass, defines when to evaluate a problem and adds the `spm_weighted_cost` example script. - [#393](https://github.com/pybop-team/PyBOP/pull/383) - Adds Minkowski and SumofPower cost classes, with an example and corresponding tests. - [#403](https://github.com/pybop-team/PyBOP/pull/403/) - Adds lychee link checking action. diff --git a/pybop/models/base_model.py b/pybop/models/base_model.py index a706923cb..75ccfbefa 100644 --- a/pybop/models/base_model.py +++ b/pybop/models/base_model.py @@ -718,3 +718,18 @@ def solver(self): @solver.setter def solver(self, solver): self._solver = solver.copy() if solver is not None else None + + def get_parameter_info(self): + """ + Extracts the parameter names and types and returns them as a dictionary. + """ + if not self.pybamm_model._built: + self.pybamm_model.build_model() + + info = self.pybamm_model.get_parameter_info() + + reduced_info = dict() + for param, param_type in info.values(): + param_name = getattr(param, "name", str(param)) + reduced_info[param_name] = param_type + return reduced_info diff --git a/tests/unit/test_models.py b/tests/unit/test_models.py index b12b3639e..08025a624 100644 --- a/tests/unit/test_models.py +++ b/tests/unit/test_models.py @@ -368,3 +368,8 @@ def test_non_converged_solution(self): for key in problem.signal: assert np.allclose(output.get(key, [])[0], output.get(key, [])) assert np.allclose(output_S1.get(key, [])[0], output_S1.get(key, [])) + + @pytest.mark.unit + def test_get_parameter_info(self, model): + parameter_info = model.get_parameter_info() + assert isinstance(parameter_info, dict) From 1128490fb6e9965682a1787b4e405062b0e5f53a Mon Sep 17 00:00:00 2001 From: NicolaCourtier <45851982+NicolaCourtier@users.noreply.github.com> Date: Thu, 25 Jul 2024 11:33:50 +0100 Subject: [PATCH 2/3] Add print_info option and test --- pybop/models/base_model.py | 7 ++++++- tests/unit/test_models.py | 15 +++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/pybop/models/base_model.py b/pybop/models/base_model.py index 75ccfbefa..9e547b656 100644 --- a/pybop/models/base_model.py +++ b/pybop/models/base_model.py @@ -719,7 +719,7 @@ def solver(self): def solver(self, solver): self._solver = solver.copy() if solver is not None else None - def get_parameter_info(self): + def get_parameter_info(self, print_info: bool = False): """ Extracts the parameter names and types and returns them as a dictionary. """ @@ -732,4 +732,9 @@ def get_parameter_info(self): for param, param_type in info.values(): param_name = getattr(param, "name", str(param)) reduced_info[param_name] = param_type + + if print_info: + for param, param_type in info.values(): + print(param, " : ", param_type) + return reduced_info diff --git a/tests/unit/test_models.py b/tests/unit/test_models.py index 08025a624..560438004 100644 --- a/tests/unit/test_models.py +++ b/tests/unit/test_models.py @@ -1,3 +1,6 @@ +import sys +from io import StringIO + import numpy as np import pybamm import pytest @@ -373,3 +376,15 @@ def test_non_converged_solution(self): def test_get_parameter_info(self, model): parameter_info = model.get_parameter_info() assert isinstance(parameter_info, dict) + + captured_output = StringIO() + sys.stdout = captured_output + + model.get_parameter_info(print_info=True) + sys.stdout = sys.__stdout__ + + printed_messaage = captured_output.getvalue().strip() + + for key, value in parameter_info.items(): + assert key in printed_messaage + assert value in printed_messaage From 7e7e135963558a5ce126cf2857b7302b66efa023 Mon Sep 17 00:00:00 2001 From: NicolaCourtier <45851982+NicolaCourtier@users.noreply.github.com> Date: Thu, 25 Jul 2024 12:24:51 +0100 Subject: [PATCH 3/3] Test one without built pybamm model --- tests/unit/test_models.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/unit/test_models.py b/tests/unit/test_models.py index 560438004..551379b71 100644 --- a/tests/unit/test_models.py +++ b/tests/unit/test_models.py @@ -374,6 +374,10 @@ def test_non_converged_solution(self): @pytest.mark.unit def test_get_parameter_info(self, model): + if isinstance(model, pybop.empirical.Thevenin): + # Test at least one model without a built pybamm model + model = pybop.empirical.Thevenin(build=False) + parameter_info = model.get_parameter_info() assert isinstance(parameter_info, dict)