Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds AdamW optimiser #316

Merged
merged 16 commits into from
Jun 10, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Merge branch 'develop' into 295-add-adamw-optimiser
BradyPlanden committed May 20, 2024
commit 877abbc7335e5ec2e149ed6139e95d5d9368e499
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -3,6 +3,8 @@
## Features

- [#316](https://github.com/pybop-team/PyBOP/pull/316) - Adds Adam with weight decay (AdamW) optimiser, replaces pints.Adam implementation.
BradyPlanden marked this conversation as resolved.
Show resolved Hide resolved
- [#321](https://github.com/pybop-team/PyBOP/pull/321) - Updates Prior classes with BaseClass, adds a `problem.sample_initial_conditions` method to improve stability of SciPy.Minimize optimiser.
- [#249](https://github.com/pybop-team/PyBOP/pull/249) - Add WeppnerHuggins model and GITT example.
- [#304](https://github.com/pybop-team/PyBOP/pull/304) - Decreases the testing suite completion time.
- [#301](https://github.com/pybop-team/PyBOP/pull/301) - Updates default echem solver to "fast with events" mode.
- [#251](https://github.com/pybop-team/PyBOP/pull/251) - Increment PyBaMM > v23.5, remove redundant tests within integration tests, increment citation version, fix examples with incorrect model definitions.

Unchanged files with check annotations Beta

def _log_init(self, logger):
"""See :meth:`Loggable._log_init()`."""
logger.add_float("b1")
logger.add_float("b2")
logger.add_float("lambda")

Check warning on line 130 in pybop/optimisers/_adamw.py

Codecov / codecov/patch

pybop/optimisers/_adamw.py#L128-L130

Added lines #L128 - L130 were not covered by tests
def _log_write(self, logger):
"""See :meth:`Loggable._log_write()`."""
logger.log(self._b1t)
logger.log(self._b2t)
logger.log(self._lambda)

Check warning on line 136 in pybop/optimisers/_adamw.py

Codecov / codecov/patch

pybop/optimisers/_adamw.py#L134-L136

Added lines #L134 - L136 were not covered by tests
def name(self):
"""
"""
The number of hyper-parameters used by this optimiser.
"""
return 1

Check warning on line 155 in pybop/optimisers/_adamw.py

Codecov / codecov/patch

pybop/optimisers/_adamw.py#L155

Added line #L155 was not covered by tests
def running(self):
"""
Returns ``True`` if the optimisation is in progress.
"""
return self._running

Check warning on line 161 in pybop/optimisers/_adamw.py

Codecov / codecov/patch

pybop/optimisers/_adamw.py#L161

Added line #L161 was not covered by tests
def tell(self, reply):
"""
# Check ask-tell pattern
if not self._ready_for_tell:
raise Exception("ask() not called before tell()")

Check warning on line 172 in pybop/optimisers/_adamw.py

Codecov / codecov/patch

pybop/optimisers/_adamw.py#L172

Added line #L172 was not covered by tests
self._ready_for_tell = False
# Unpack reply
"""
Returns the last guessed parameter values.
"""
return self._current

Check warning on line 219 in pybop/optimisers/_adamw.py

Codecov / codecov/patch

pybop/optimisers/_adamw.py#L219

Added line #L219 was not covered by tests
def set_lambda(self, lambda_=0.01):
"""
raise TypeError("lambda_ must be numeric, floatable value.")
if not 0 < lambda_ <= 1:
print("lambda_ must a positive value between 0 and 1")

Check warning on line 232 in pybop/optimisers/_adamw.py

Codecov / codecov/patch

pybop/optimisers/_adamw.py#L232

Added line #L232 was not covered by tests
self._lambda = lambda_
return
raise TypeError("b1 must be numeric, floatable value.")
if not 0 < b1 <= 1:
print("b1 must a positive value between 0 and 1")

Check warning on line 247 in pybop/optimisers/_adamw.py

Codecov / codecov/patch

pybop/optimisers/_adamw.py#L247

Added line #L247 was not covered by tests
self._b1 = b1
return
raise TypeError("b2 must be numeric, floatable value.")
if not 0 < b2 <= 1:
print("b2 must a positive value between 0 and 1")

Check warning on line 262 in pybop/optimisers/_adamw.py

Codecov / codecov/patch

pybop/optimisers/_adamw.py#L262

Added line #L262 was not covered by tests
self._b2 = b2
return
You are viewing a condensed version of this merge commit. You can view the full changes here.