Skip to content

Commit

Permalink
#1129 fix mtk model generation
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Dec 2, 2021
1 parent 61b6f13 commit 802966d
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 23 deletions.
15 changes: 5 additions & 10 deletions pybamm/expression_tree/operations/evaluate_julia.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def find_symbols(symbol, constant_symbols, variable_symbols, variable_symbol_siz
symbol_str = "inputs['{}']".format(symbol.name)

elif isinstance(symbol, pybamm.SpatialVariable):
symbol_str = symbol.name.replace("_", "")
symbol_str = symbol.name

elif isinstance(symbol, pybamm.FunctionParameter):
symbol_str = "{}({})".format(symbol.name, ", ".join(children_vars))
Expand Down Expand Up @@ -773,7 +773,7 @@ def get_julia_mtk_model(model, geometry=None, tspan=None):
variable_id_to_print_name = {}
for i, var in enumerate(variables):
if var.print_name is not None:
print_name = var.print_name
print_name = var._raw_print_name
else:
print_name = f"u{i+1}"
variable_id_to_print_name[var.id] = print_name
Expand Down Expand Up @@ -808,24 +808,19 @@ def get_julia_mtk_model(model, geometry=None, tspan=None):
long_domain_symbol_to_short = {}
for dom in all_domains:
# Read domain name from geometry
domain_symbol = list(geometry[dom[0]].keys())[0].name.replace("_", "")
domain_symbol = list(geometry[dom[0]].keys())[0]
if len(dom) > 1:
domain_symbol = domain_symbol[0]
# For multi-domain variables keep only the first letter of the domain
domain_name_to_symbol[tuple(dom)] = domain_symbol
# Record which domain symbols we shortened
for d in dom:
long = list(geometry[d].keys())[0].name.replace("_", "")
long = list(geometry[d].keys())[0]
long_domain_symbol_to_short[long] = domain_symbol
else:
# Otherwise keep the whole domain
domain_name_to_symbol[tuple(dom)] = domain_symbol

# Read coordinate systems
domain_name_to_coord_sys = {
tuple(dom): list(geometry[dom[0]].keys())[0].coord_sys for dom in all_domains
}

# Read domain limits
domain_name_to_limits = {}
for dom in all_domains:
Expand Down Expand Up @@ -958,7 +953,7 @@ def get_julia_mtk_model(model, geometry=None, tspan=None):
f"grad_{domain_name}", f"D{domain_symbol}"
)
# Different divergence depending on the coordinate system
coord_sys = domain_name_to_coord_sys[domain_name]
coord_sys = getattr(pybamm.standard_spatial_vars, domain_symbol).coord_sys
if coord_sys == "cartesian":
all_julia_str = all_julia_str.replace(
f"div_{domain_name}", f"D{domain_symbol}"
Expand Down
30 changes: 21 additions & 9 deletions pybamm/expression_tree/printing/print_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def prettify_print_name(name):
"""Prettify print_name using regex"""

# Skip prettify_print_name() for cases like `new_copy()`
if "{" in name or "\\" in name:
if name is None or "{" in name or "\\" in name:
return name

# Return print_name if name exists in the dictionary
Expand All @@ -56,8 +56,9 @@ def prettify_print_name(name):
# Superscripts with comma separated (U_n_ref --> U_{n}^{ref})
sup_re = re.search(r"^[\da-zA-Z]+_?(.*?)_?((?:init|ref|typ|max|0))", name)
if sup_re:
sup_str = (r"{" + sup_re.group(1).replace("_", "\,") + r"}^{" +
sup_re.group(2) + r"}")
sup_str = (
r"{" + sup_re.group(1).replace("_", "\,") + r"}^{" + sup_re.group(2) + r"}"
)
sup_var = sup_re.group(1) + "_" + sup_re.group(2)
name = name.replace(sup_var, sup_str)

Expand All @@ -71,17 +72,28 @@ def prettify_print_name(name):
dim_re = re.search(r"([\da-zA-Z]+)_?(.*?)_?(?:dim|dimensional)", name)
if dim_re:
if "^" in name:
name = (r"\hat{" + dim_re.group(1) + r"}_" +
dim_re.group(2).replace("_", "\,"))
name = (
r"\hat{" + dim_re.group(1) + r"}_" + dim_re.group(2).replace("_", "\,")
)
else:
name = (r"\hat{" + dim_re.group(1) + r"}_{" +
dim_re.group(2).replace("_", "\,") + r"}")
name = (
r"\hat{"
+ dim_re.group(1)
+ r"}_{"
+ dim_re.group(2).replace("_", "\,")
+ r"}"
)

# Bar with comma separated (c_s_n_xav --> \bar{c}_{s\,n})
bar_re = re.search(r"^([a-zA-Z]+)_*(\w*?)_(?:av|xav)", name)
if bar_re:
name = (r"\bar{" + bar_re.group(1) + r"}_{" +
bar_re.group(2).replace("_", "\,") + r"}")
name = (
r"\bar{"
+ bar_re.group(1)
+ r"}_{"
+ bar_re.group(2).replace("_", "\,")
+ r"}"
)

# Replace eps with epsilon (eps_n --> epsilon_n)
name = re.sub(r"(eps)(?![0-9a-zA-Z])", "epsilon", name)
Expand Down
6 changes: 2 additions & 4 deletions pybamm/expression_tree/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -979,10 +979,8 @@ def print_name(self):

@print_name.setter
def print_name(self, name):
if name is None:
self._print_name = name
else:
self._print_name = prettify_print_name(name)
self._raw_print_name = name
self._print_name = prettify_print_name(name)

def to_equation(self):
return sympy.Symbol(str(self.name))

0 comments on commit 802966d

Please sign in to comment.