Skip to content

Commit

Permalink
Merge pull request pybamm-team#3396 from kratman/feat/smoothmin
Browse files Browse the repository at this point in the history
Add smooth min and smooth max
  • Loading branch information
valentinsulzer authored Dec 21, 2023
2 parents 1943aa5 + 7177160 commit 7e72c38
Show file tree
Hide file tree
Showing 5 changed files with 373 additions and 117 deletions.
331 changes: 254 additions & 77 deletions docs/source/examples/notebooks/solvers/speed-up-solver.ipynb

Large diffs are not rendered by default.

38 changes: 31 additions & 7 deletions pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -1239,11 +1239,14 @@ def minimum(left, right):
if out is not None:
return out

k = pybamm.settings.min_smoothing
mode = pybamm.settings.min_max_mode
k = pybamm.settings.min_max_smoothing
# Return exact approximation if that is the setting or the outcome is a constant
# (i.e. no need for smoothing)
if k == "exact" or (left.is_constant() and right.is_constant()):
if mode == "exact" or (left.is_constant() and right.is_constant()):
out = Minimum(left, right)
elif mode == "smooth":
out = pybamm.smooth_min(left, right, k)
else:
out = pybamm.softminus(left, right, k)
return pybamm.simplify_if_constant(out)
Expand All @@ -1260,11 +1263,14 @@ def maximum(left, right):
if out is not None:
return out

k = pybamm.settings.max_smoothing
mode = pybamm.settings.min_max_mode
k = pybamm.settings.min_max_smoothing
# Return exact approximation if that is the setting or the outcome is a constant
# (i.e. no need for smoothing)
if k == "exact" or (left.is_constant() and right.is_constant()):
if mode == "exact" or (left.is_constant() and right.is_constant()):
out = Maximum(left, right)
elif mode == "smooth":
out = pybamm.smooth_max(left, right, k)
else:
out = pybamm.softplus(left, right, k)
return pybamm.simplify_if_constant(out)
Expand Down Expand Up @@ -1311,20 +1317,38 @@ def _heaviside(left, right, equal):

def softminus(left, right, k):
"""
Softplus approximation to the minimum function. k is the smoothing parameter,
set by `pybamm.settings.min_smoothing`. The recommended value is k=10.
Softminus approximation to the minimum function. k is the smoothing parameter,
set by `pybamm.settings.min_max_smoothing`. The recommended value is k=10.
"""
return pybamm.log(pybamm.exp(-k * left) + pybamm.exp(-k * right)) / -k


def softplus(left, right, k):
"""
Softplus approximation to the maximum function. k is the smoothing parameter,
set by `pybamm.settings.max_smoothing`. The recommended value is k=10.
set by `pybamm.settings.min_max_smoothing`. The recommended value is k=10.
"""
return pybamm.log(pybamm.exp(k * left) + pybamm.exp(k * right)) / k


def smooth_min(left, right, k):
"""
Smooth_min approximation to the minimum function. k is the smoothing parameter,
set by `pybamm.settings.min_max_smoothing`. The recommended value is k=100.
"""
sigma = (1.0 / k)**2
return ((left + right) - (pybamm.sqrt((left - right)**2 + sigma))) / 2


def smooth_max(left, right, k):
"""
Smooth_max approximation to the maximum function. k is the smoothing parameter,
set by `pybamm.settings.min_max_smoothing`. The recommended value is k=100.
"""
sigma = (1.0 / k) ** 2
return (pybamm.sqrt((left - right)**2 + sigma) + (left + right)) / 2


def sigmoid(left, right, k):
"""
Sigmoidal approximation to the heaviside function. k is the smoothing parameter,
Expand Down
52 changes: 33 additions & 19 deletions pybamm/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
class Settings:
_debug_mode = False
_simplify = True
_min_smoothing = "exact"
_max_smoothing = "exact"
_min_max_mode = "exact"
_min_max_smoothing = 10
_heaviside_smoothing = "exact"
_abs_smoothing = "exact"
max_words_in_line = 4
Expand Down Expand Up @@ -43,35 +43,49 @@ def simplify(self, value):
self._simplify = value

def set_smoothing_parameters(self, k):
"Helper function to set all smoothing parameters"
self.min_smoothing = k
self.max_smoothing = k
"""Helper function to set all smoothing parameters"""
if k == "exact":
self.min_max_mode = "exact"
else:
self.min_max_smoothing = k
self.min_max_mode = "soft"
self.heaviside_smoothing = k
self.abs_smoothing = k

def check_k(self, k):
@staticmethod
def check_k(k):
if k != "exact" and k <= 0:
raise ValueError(
"smoothing parameter must be 'exact' or a strictly positive number"
"Smoothing parameter must be 'exact' or a strictly positive number"
)

@property
def min_smoothing(self):
return self._min_smoothing
def min_max_mode(self):
return self._min_max_mode

@min_smoothing.setter
def min_smoothing(self, k):
self.check_k(k)
self._min_smoothing = k
@min_max_mode.setter
def min_max_mode(self, mode):
if mode not in ["exact", "soft", "smooth"]:
raise ValueError(
"Smoothing mode must be 'exact', 'soft', or 'smooth'"
)
self._min_max_mode = mode

@property
def max_smoothing(self):
return self._max_smoothing
def min_max_smoothing(self):
return self._min_max_smoothing

@max_smoothing.setter
def max_smoothing(self, k):
self.check_k(k)
self._max_smoothing = k
@min_max_smoothing.setter
def min_max_smoothing(self, k):
if self._min_max_mode == "soft" and k <= 0:
raise ValueError(
"Smoothing parameter must be a strictly positive number"
)
if self._min_max_mode == "smooth" and k < 1:
raise ValueError(
"Smoothing parameter must be greater than 1"
)
self._min_max_smoothing = k

@property
def heaviside_smoothing(self):
Expand Down
46 changes: 42 additions & 4 deletions tests/unit/test_expression_tree/test_binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,8 +412,8 @@ def test_softminus_softplus(self):
)

# Test that smooth min/max are used when the setting is changed
pybamm.settings.min_smoothing = 10
pybamm.settings.max_smoothing = 10
pybamm.settings.min_max_mode = "soft"
pybamm.settings.min_max_smoothing = 10

self.assertEqual(str(pybamm.minimum(a, b)), str(pybamm.softminus(a, b, 10)))
self.assertEqual(str(pybamm.maximum(a, b)), str(pybamm.softplus(a, b, 10)))
Expand All @@ -425,8 +425,46 @@ def test_softminus_softplus(self):
self.assertEqual(str(pybamm.maximum(a, b)), str(b))

# Change setting back for other tests
pybamm.settings.min_smoothing = "exact"
pybamm.settings.max_smoothing = "exact"
pybamm.settings.set_smoothing_parameters("exact")

def test_smooth_minus_plus(self):
a = pybamm.Scalar(1)
b = pybamm.StateVector(slice(0, 1))

minimum = pybamm.smooth_min(a, b, 3000)
self.assertAlmostEqual(minimum.evaluate(y=np.array([2]))[0, 0], 1)
self.assertAlmostEqual(minimum.evaluate(y=np.array([0]))[0, 0], 0)

maximum = pybamm.smooth_max(a, b, 3000)
self.assertAlmostEqual(maximum.evaluate(y=np.array([2]))[0, 0], 2)
self.assertAlmostEqual(maximum.evaluate(y=np.array([0]))[0, 0], 1)

minimum = pybamm.smooth_min(a, b, 1)
self.assertEqual(
str(minimum),
"0.5 * (1.0 + y[0:1] - sqrt(1.0 + (1.0 - y[0:1]) ** 2.0))",
)
maximum = pybamm.smooth_max(a, b, 1)
self.assertEqual(
str(maximum),
"0.5 * (sqrt(1.0 + (1.0 - y[0:1]) ** 2.0) + 1.0 + y[0:1])",
)

# Test that smooth min/max are used when the setting is changed
pybamm.settings.min_max_mode = "smooth"

pybamm.settings.min_max_smoothing = 1
self.assertEqual(str(pybamm.minimum(a, b)), str(pybamm.smooth_min(a, b, 1)))
self.assertEqual(str(pybamm.maximum(a, b)), str(pybamm.smooth_max(a, b, 1)))

pybamm.settings.min_max_smoothing = 3000
a = pybamm.Scalar(1)
b = pybamm.Scalar(2)
self.assertEqual(str(pybamm.minimum(a, b)), str(a))
self.assertEqual(str(pybamm.maximum(a, b)), str(b))

# Change setting back for other tests
pybamm.settings.set_smoothing_parameters("exact")

def test_binary_simplifications(self):
a = pybamm.Scalar(0)
Expand Down
23 changes: 13 additions & 10 deletions tests/unit/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,30 @@ def test_simplify(self):
pybamm.settings.simplify = True

def test_smoothing_parameters(self):
self.assertEqual(pybamm.settings.min_smoothing, "exact")
self.assertEqual(pybamm.settings.max_smoothing, "exact")
self.assertEqual(pybamm.settings.min_max_mode, "exact")
self.assertEqual(pybamm.settings.heaviside_smoothing, "exact")
self.assertEqual(pybamm.settings.abs_smoothing, "exact")

pybamm.settings.set_smoothing_parameters(10)
self.assertEqual(pybamm.settings.min_smoothing, 10)
self.assertEqual(pybamm.settings.max_smoothing, 10)
self.assertEqual(pybamm.settings.min_max_smoothing, 10)
self.assertEqual(pybamm.settings.heaviside_smoothing, 10)
self.assertEqual(pybamm.settings.abs_smoothing, 10)
pybamm.settings.set_smoothing_parameters("exact")

# Test errors
with self.assertRaisesRegex(ValueError, "strictly positive"):
pybamm.settings.min_smoothing = -10
with self.assertRaisesRegex(ValueError, "strictly positive"):
pybamm.settings.max_smoothing = -10
with self.assertRaisesRegex(ValueError, "strictly positive"):
with self.assertRaisesRegex(ValueError, "greater than 1"):
pybamm.settings.min_max_mode = "smooth"
pybamm.settings.min_max_smoothing = 0.9
with self.assertRaisesRegex(ValueError, "positive number"):
pybamm.settings.min_max_mode = "soft"
pybamm.settings.min_max_smoothing = -10
with self.assertRaisesRegex(ValueError, "positive number"):
pybamm.settings.heaviside_smoothing = -10
with self.assertRaisesRegex(ValueError, "strictly positive"):
with self.assertRaisesRegex(ValueError, "positive number"):
pybamm.settings.abs_smoothing = -10
with self.assertRaisesRegex(ValueError, "'soft', or 'smooth'"):
pybamm.settings.min_max_mode = "unknown"
pybamm.settings.set_smoothing_parameters("exact")


if __name__ == "__main__":
Expand Down

0 comments on commit 7e72c38

Please sign in to comment.