Skip to content

Commit

Permalink
Merge pull request #2443 from pybamm-team/more-simplifications
Browse files Browse the repository at this point in the history
More simplifications
  • Loading branch information
valentinsulzer authored Nov 10, 2022
2 parents d5547da + 37b96d0 commit c0aa28d
Show file tree
Hide file tree
Showing 40 changed files with 401 additions and 352 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

## Optimizations

- Added more rules for simplifying expressions, especially around Concatenations. Also, meshes constructed from multiple domains are now cached ([#2443](https://github.com/pybamm-team/PyBaMM/pull/2443))
- Added more rules for simplifying expressions. Constants in binary operators are now moved to the left by default (e.g. `x*2` returns `2*x`) ([#2424](https://github.com/pybamm-team/PyBaMM/pull/2424))

## Breaking changes
Expand Down
4 changes: 2 additions & 2 deletions examples/scripts/compare_comsol/compare_comsol_DFN.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def get_interp_fun(variable_name, domain):
comsol_x = comsol_variables["x"]

# Make sure to use dimensional space
pybamm_x = mesh.combine_submeshes(*domain).nodes * L_x
pybamm_x = mesh[domain].nodes * L_x
variable = interp.interp1d(comsol_x, variable, axis=0)(pybamm_x)

fun = pybamm.Interpolant(
Expand All @@ -88,7 +88,7 @@ def get_interp_fun(variable_name, domain):
)

fun.domains = {"primary": domain}
fun.mesh = mesh.combine_submeshes(*domain)
fun.mesh = mesh[domain]
fun.secondary_mesh = None
return fun

Expand Down
23 changes: 11 additions & 12 deletions pybamm/discretisations/discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,7 +750,10 @@ def process_dict(self, var_eqn_dict):
for eqn_key, eqn in var_eqn_dict.items():
# Broadcast if the equation evaluates to a number (e.g. Scalar)
if np.prod(eqn.shape_for_testing) == 1 and not isinstance(eqn_key, str):
eqn = pybamm.FullBroadcast(eqn, broadcast_domains=eqn_key.domains)
if eqn_key.domain == []:
eqn = eqn * pybamm.Vector([1])
else:
eqn = pybamm.FullBroadcast(eqn, broadcast_domains=eqn_key.domains)

pybamm.logger.debug("Discretise {!r}".format(eqn_key))

Expand Down Expand Up @@ -784,14 +787,14 @@ def process_symbol(self, symbol):

# Assign mesh as an attribute to the processed variable
if symbol.domain != []:
discretised_symbol.mesh = self.mesh.combine_submeshes(*symbol.domain)
discretised_symbol.mesh = self.mesh[symbol.domain]
else:
discretised_symbol.mesh = None
# Assign secondary mesh
if symbol.domains["secondary"] != []:
discretised_symbol.secondary_mesh = self.mesh.combine_submeshes(
*symbol.domains["secondary"]
)
discretised_symbol.secondary_mesh = self.mesh[
symbol.domains["secondary"]
]
else:
discretised_symbol.secondary_mesh = None
return discretised_symbol
Expand Down Expand Up @@ -897,13 +900,9 @@ def _process_symbol(self, symbol):
elif isinstance(symbol, pybamm.Broadcast):
# Broadcast new_child to the domain specified by symbol.domain
# Different discretisations may broadcast differently
if symbol.domain == []:
out = disc_child * pybamm.Vector([1])
else:
out = spatial_method.broadcast(
disc_child, symbol.domains, symbol.broadcast_type
)
return out
return spatial_method.broadcast(
disc_child, symbol.domains, symbol.broadcast_type
)

elif isinstance(symbol, pybamm.DeltaFunction):
return spatial_method.delta_function(symbol, disc_child)
Expand Down
42 changes: 14 additions & 28 deletions pybamm/expression_tree/averages.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,36 +144,22 @@ def x_average(symbol):
else: # pragma: no cover
# It should be impossible to get here
raise NotImplementedError
# If symbol is a concatenation of Broadcasts, its average value is the
# thickness-weighted average of the symbols being broadcasted
elif isinstance(symbol, pybamm.Concatenation) and all(
isinstance(child, pybamm.Broadcast) for child in symbol.children
# If symbol is a concatenation, its average value is the
# thickness-weighted average of the average of its children
elif isinstance(symbol, pybamm.Concatenation) and not isinstance(
symbol, pybamm.ConcatenationVariable
):
geo = pybamm.geometric_parameters
l_n = geo.n.l
l_s = geo.s.l
l_p = geo.p.l
if symbol.domain == ["negative electrode", "separator", "positive electrode"]:
a, b, c = [orp.orphans[0] for orp in symbol.orphans]
out = (l_n * a + l_s * b + l_p * c) / (l_n + l_s + l_p)
elif symbol.domain == ["separator", "positive electrode"]:
b, c = [orp.orphans[0] for orp in symbol.orphans]
out = (l_s * b + l_p * c) / (l_s + l_p)
# To respect domains we may need to broadcast the child back out
child = symbol.children[0]
# If symbol being returned doesn't have empty domain, return it
if out.domain != []:
return out
# Otherwise we may need to broadcast it
elif child.domains["secondary"] == []:
return out
else:
domain = child.domains["secondary"]
if child.domains["tertiary"] == []:
return pybamm.PrimaryBroadcast(out, domain)
else:
auxiliary_domains = {"secondary": child.domains["tertiary"]}
return pybamm.FullBroadcast(out, domain, auxiliary_domains)
ls = {
("negative electrode",): geo.n.l,
("separator",): geo.s.l,
("positive electrode",): geo.p.l,
("separator", "positive electrode"): geo.s.l + geo.p.l,
}
out = sum(
ls[tuple(orp.domain)] * x_average(orp) for orp in symbol.orphans
) / sum(ls[tuple(orp.domain)] for orp in symbol.orphans)
return out
# Average of a sum is sum of averages
elif isinstance(symbol, (pybamm.Addition, pybamm.Subtraction)):
return _sum_of_averages(symbol, x_average)
Expand Down
12 changes: 2 additions & 10 deletions pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,16 +711,8 @@ def _simplified_binary_broadcast_concatenation(left, right, operator):
return left._concatenation_new_copy(
[operator(child, right) for child in left.orphans]
)
elif (
isinstance(right, pybamm.Concatenation)
and not any(
isinstance(child, (pybamm.Variable, pybamm.StateVector))
for child in right.children
)
and (
all(child.is_constant() for child in left.children)
or all(child.is_constant() for child in right.children)
)
elif isinstance(right, pybamm.Concatenation) and not isinstance(
right, pybamm.ConcatenationVariable
):
return left._concatenation_new_copy(
[
Expand Down
9 changes: 8 additions & 1 deletion pybamm/expression_tree/broadcasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ def check_and_set_domains(self, child, broadcast_domain):
# Can only do primary broadcast from current collector to electrode,
# particle-size or particle or from electrode to particle-size or particle.
# Note e.g. current collector to particle *is* allowed
if broadcast_domain == []:
raise pybamm.DomainError("Cannot Broadcast an object into empty domain.")
if child.domain == []:
pass
elif child.domain == ["current collector"] and not (
Expand Down Expand Up @@ -430,7 +432,10 @@ def __init__(

def check_and_set_domains(self, child, broadcast_domains):
"""See :meth:`Broadcast.check_and_set_domains`"""

if broadcast_domains["primary"] == []:
raise pybamm.DomainError(
"""Cannot do full broadcast to an empty primary domain"""
)
# Variables on the current collector can only be broadcast to 'primary'
if child.domain == ["current collector"]:
raise pybamm.DomainError(
Expand Down Expand Up @@ -544,6 +549,8 @@ def full_like(symbols, fill_value):
return array_type(entries, domains=sum_symbol.domains)

except NotImplementedError:
if sum_symbol.shape_for_testing == (1, 1):
return pybamm.Scalar(fill_value)
if sum_symbol.evaluates_on_edges("primary"):
return FullBroadcastToEdges(
fill_value, broadcast_domains=sum_symbol.domains
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/concatenations.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def _get_auxiliary_domain_repeats(self, auxiliary_domains):
mesh_pts = 1
for level, dom in auxiliary_domains.items():
if level != "primary" and dom != []:
mesh_pts *= self.full_mesh.combine_submeshes(*dom).npts
mesh_pts *= self.full_mesh[dom].npts
return mesh_pts

@property
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ def __abs__(self):
elif isinstance(self, pybamm.Broadcast):
# Move absolute value inside the broadcast
# Apply recursively
abs_self_not_broad = pybamm.simplify_if_constant(abs(self.orphans[0]))
abs_self_not_broad = abs(self.orphans[0])
return self._unary_new_copy(abs_self_not_broad)
else:
k = pybamm.settings.abs_smoothing
Expand Down
21 changes: 19 additions & 2 deletions pybamm/expression_tree/unary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,9 @@ def __init__(self, children, initial_condition):
def _unary_new_copy(self, child):
return self.__class__(child, self.initial_condition)

def is_constant(self):
return False


class BoundaryGradient(BoundaryOperator):
"""
Expand Down Expand Up @@ -1084,7 +1087,10 @@ def grad(symbol):
"""
# Gradient of a broadcast is zero
if isinstance(symbol, pybamm.PrimaryBroadcast):
new_child = pybamm.PrimaryBroadcast(0, symbol.child.domain)
if symbol.child.domain == []:
new_child = pybamm.Scalar(0)
else:
new_child = pybamm.PrimaryBroadcast(0, symbol.child.domain)
return pybamm.PrimaryBroadcastToEdges(new_child, symbol.domain)
elif isinstance(symbol, pybamm.FullBroadcast):
return pybamm.FullBroadcastToEdges(0, broadcast_domains=symbol.domains)
Expand All @@ -1110,7 +1116,10 @@ def div(symbol):
"""
# Divergence of a broadcast is zero
if isinstance(symbol, pybamm.PrimaryBroadcastToEdges):
new_child = pybamm.PrimaryBroadcast(0, symbol.child.domain)
if symbol.child.domain == []:
new_child = pybamm.Scalar(0)
else:
new_child = pybamm.PrimaryBroadcast(0, symbol.child.domain)
return pybamm.PrimaryBroadcast(new_child, symbol.domain)
# Divergence commutes with Negate operator
if isinstance(symbol, pybamm.Negate):
Expand Down Expand Up @@ -1245,6 +1254,14 @@ def boundary_value(symbol, side):

def sign(symbol):
"""Returns a :class:`Sign` object."""
if isinstance(symbol, pybamm.Broadcast):
# Move sign inside the broadcast
# Apply recursively
return symbol._unary_new_copy(sign(symbol.orphans[0]))
elif isinstance(symbol, pybamm.Concatenation) and not isinstance(
symbol, pybamm.ConcatenationVariable
):
return pybamm.concatenation(*[sign(child) for child in symbol.orphans])
return pybamm.simplify_if_constant(Sign(symbol))


Expand Down
32 changes: 24 additions & 8 deletions pybamm/meshes/meshes.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,30 @@ def __init__(self, geometry, submesh_types, var_pts):
geometry[domain][spatial_variable][lim] = sym_eval

# Create submeshes
self.base_domains = []
for domain in geometry:
self[domain] = submesh_types[domain](geometry[domain], submesh_pts[domain])
self.base_domains.append(domain)

# add ghost meshes
self.add_ghost_meshes()

def __getitem__(self, domains):
if isinstance(domains, str):
domains = (domains,)
domains = tuple(domains)
try:
return super().__getitem__(domains)
except KeyError:
value = self.combine_submeshes(*domains)
self[domains] = value
return value

def __setitem__(self, domains, value):
if isinstance(domains, str):
domains = (domains,)
super().__setitem__(domains, value)

def combine_submeshes(self, *submeshnames):
"""Combine submeshes into a new submesh, using self.submeshclass
Raises pybamm.DomainError if submeshes to be combined do not match up (edges are
Expand All @@ -134,9 +152,6 @@ def combine_submeshes(self, *submeshnames):
"""
if submeshnames == ():
raise ValueError("Submesh domains being combined cannot be empty")
# If there is just a single submesh, we can return it directly
if len(submeshnames) == 1:
return self[submeshnames[0]]
# Check that the final edge of each submesh is the same as the first edge of the
# next submesh
for i in range(len(submeshnames) - 1):
Expand All @@ -159,7 +174,6 @@ def combine_submeshes(self, *submeshnames):
submesh.internal_boundaries = [
self[submeshname].edges[0] for submeshname in submeshnames[1:]
]

return submesh

def add_ghost_meshes(self):
Expand All @@ -172,22 +186,24 @@ def add_ghost_meshes(self):
submeshes = [
(domain, submesh)
for domain, submesh in self.items()
if not isinstance(submesh, (pybamm.SubMesh0D, pybamm.ScikitSubMesh2D))
if (
len(domain) == 1
and not isinstance(submesh, (pybamm.SubMesh0D, pybamm.ScikitSubMesh2D))
)
]
for domain, submesh in submeshes:

edges = submesh.edges

# left ghost cell: two edges, one node, to the left of existing submesh
lgs_edges = np.array([2 * edges[0] - edges[1], edges[0]])
self[domain + "_left ghost cell"] = pybamm.SubMesh1D(
self[domain[0] + "_left ghost cell"] = pybamm.SubMesh1D(
lgs_edges, submesh.coord_sys
)

# right ghost cell: two edges, one node, to the right of
# existing submesh
rgs_edges = np.array([edges[-1], 2 * edges[-1] - edges[-2]])
self[domain + "_right ghost cell"] = pybamm.SubMesh1D(
self[domain[0] + "_right ghost cell"] = pybamm.SubMesh1D(
rgs_edges, submesh.coord_sys
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,17 @@ def _get_standard_concentration_variables(self, c_e_dict):
electrolyte.
"""

c_e_typ = self.param.c_e_typ
c_e = pybamm.concatenation(*c_e_dict.values())
# Override print_name
c_e.print_name = "c_e"

variables = {
"Electrolyte concentration": c_e,
"X-averaged electrolyte concentration": pybamm.x_average(c_e),
}
variables = self._get_standard_domain_concentration_variables(c_e_dict)
variables.update(self._get_standard_whole_cell_concentration_variables(c_e))
return variables

def _get_standard_domain_concentration_variables(self, c_e_dict):
c_e_typ = self.param.c_e_typ
variables = {}
# Case where an electrode is not included (half-cell)
if "negative electrode" not in self.options.whole_cell_domains:
c_e_s = c_e_dict["separator"]
Expand Down Expand Up @@ -75,6 +76,24 @@ def _get_standard_concentration_variables(self, c_e_dict):

return variables

def _get_standard_whole_cell_concentration_variables(self, c_e):
c_e_typ = self.param.c_e_typ

variables = {
"Electrolyte concentration": c_e,
"X-averaged electrolyte concentration": pybamm.x_average(c_e),
}
variables_nondim = variables.copy()
for name, var in variables_nondim.items():
variables.update(
{
f"{name} [mol.m-3]": c_e_typ * var,
f"{name} [Molar]": c_e_typ * var / 1000,
}
)

return variables

def _get_standard_porosity_times_concentration_variables(self, eps_c_e_dict):
eps_c_e = pybamm.concatenation(*eps_c_e_dict.values())
variables = {"Porosity times concentration": eps_c_e}
Expand Down
Loading

0 comments on commit c0aa28d

Please sign in to comment.