Skip to content

Commit

Permalink
Merge pull request #881 from pybamm-team/get-function-args
Browse files Browse the repository at this point in the history
Get function args
  • Loading branch information
Scottmar93 authored Mar 30, 2020
2 parents 95b0473 + 7e66610 commit 891424b
Show file tree
Hide file tree
Showing 26 changed files with 373 additions and 122 deletions.
7 changes: 4 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@
- Added functionality to broadcast to edges ([#891](https://github.com/pybamm-team/PyBaMM/pull/891))
- Reformatted and cleaned up `QuickPlot` ([#886](https://github.com/pybamm-team/PyBaMM/pull/886))
- Added thermal effects to lead-acid models ([#885](https://github.com/pybamm-team/PyBaMM/pull/885))
- Add new symbols `VariableDot`, representing the derivative of a variable wrt time,
and `StateVectorDot`, representing the derivative of a state vector wrt time
([#858](https://github.com/pybamm-team/PyBaMM/issues/858))
- Added a helper function for info on function parameters ([#881](https://github.com/pybamm-team/PyBaMM/pull/881))
- Added additional notebooks showing how to create and compare models ([#877](https://github.com/pybamm-team/PyBaMM/pull/877))
- Added `Minimum`, `Maximum` and `Sign` operators
([#876](https://github.com/pybamm-team/PyBaMM/pull/876))
- Added a search feature to `FuzzyDict` ([#875](https://github.com/pybamm-team/PyBaMM/pull/875))
- Add ambient temperature as a function of time ([#872](https://github.com/pybamm-team/PyBaMM/pull/872))
- Added `CasadiAlgebraicSolver` for solving algebraic systems with CasADi ([#868](https://github.com/pybamm-team/PyBaMM/pull/868))
- Added electrolyte functions from Landesfeind ([#860](https://github.com/pybamm-team/PyBaMM/pull/860))
- Add new symbols `VariableDot`, representing the derivative of a variable wrt time,
and `StateVectorDot`, representing the derivative of a state vector wrt time
([#858](https://github.com/pybamm-team/PyBaMM/issues/858))

## Bug fixes

Expand Down
2 changes: 1 addition & 1 deletion examples/notebooks/Creating Models/1-an-ode-model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
"version": "3.7.5"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion examples/notebooks/Creating Models/2-a-pde-model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
"version": "3.7.5"
}
},
"nbformat": 4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
"version": "3.7.5"
}
},
"nbformat": 4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
"version": "3.7.5"
}
},
"nbformat": 4,
Expand Down
6 changes: 3 additions & 3 deletions examples/notebooks/Creating Models/5-a-simple-SEI-model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@
"c_inf_dim = pybamm.Parameter(\"Bulk electrolyte solvent concentration\")\n",
"\n",
"def D_dim(cc):\n",
" return pybamm.FunctionParameter(\"Diffusivity\", cc)\n",
" return pybamm.FunctionParameter(\"Diffusivity\", {\"Solvent concentration [mol.m-3]\": cc})\n",
"\n",
"# dimensionless parameters\n",
"k = k_dim * L_0_dim / D_dim(c_inf_dim)\n",
Expand Down Expand Up @@ -591,7 +591,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b2a5f649a3b64685ad9510649f130829",
"model_id": "efe1fe18458a42d88056baf689f6da80",
"version_major": 2,
"version_minor": 0
},
Expand Down Expand Up @@ -653,7 +653,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
"version": "3.7.5"
}
},
"nbformat": 4,
Expand Down
18 changes: 9 additions & 9 deletions examples/notebooks/parameter-values.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion examples/scripts/create-model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


def D_dim(cc):
return pybamm.FunctionParameter("Diffusivity", cc)
return pybamm.FunctionParameter("Diffusivity", {"Concentration [mol.m-3]": cc})


# dimensionless parameters
Expand Down
65 changes: 56 additions & 9 deletions pybamm/expression_tree/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,19 +48,23 @@ class FunctionParameter(pybamm.Symbol):
name : str
name of the node
child : :class:`Symbol`
child node
inputs : dict
A dictionary with string keys and :class:`pybamm.Symbol` values representing
the function inputs. The string keys should provide a reasonable description
of what the input to the function is
(e.g. "Electrolyte concentration [mol.m-3]")
diff_variable : :class:`pybamm.Symbol`, optional
if diff_variable is specified, the FunctionParameter node will be replaced by a
:class:`pybamm.Function` and then differentiated with respect to diff_variable.
Default is None.
"""

def __init__(self, name, *children, diff_variable=None):
def __init__(
self, name, inputs, diff_variable=None,
):
# assign diff variable
self.diff_variable = diff_variable
children_list = list(children)
children_list = list(inputs.values())

# Turn numbers into scalars
for idx, child in enumerate(children_list):
Expand All @@ -76,6 +80,37 @@ def __init__(self, name, *children, diff_variable=None):
auxiliary_domains=auxiliary_domains,
)

self.input_names = list(inputs.keys())

@property
def input_names(self):
return self._input_names

def print_input_names(self):
if self._input_names:
for inp in self._input_names:
print(inp)

@input_names.setter
def input_names(self, inp=None):
if inp:
if inp.__class__ is list:
for i in inp:
if i.__class__ is not str:
raise TypeError(
"Inputs must be a provided as"
+ "a dictionary of the form:"
+ "{{str: :class:`pybamm.Symbol`}}"
)
else:
raise TypeError(
"Inputs must be a provided as"
+ " a dictionary of the form:"
+ "{{str: :class:`pybamm.Symbol`}}"
)

self._input_names = inp

def set_id(self):
"""See :meth:`pybamm.Symbol.set_id` """
self._id = hash(
Expand Down Expand Up @@ -107,17 +142,24 @@ def diff(self, variable):
""" See :meth:`pybamm.Symbol.diff()`. """
# return a new FunctionParameter, that knows it will need to be differentiated
# when the parameters are set
return FunctionParameter(self.name, *self.orphans, diff_variable=variable)
children_list = self.orphans
input_names = self._input_names

input_dict = {input_names[i]: children_list[i] for i in range(len(input_names))}

return FunctionParameter(self.name, input_dict, diff_variable=variable)

def new_copy(self):
""" See :meth:`pybamm.Symbol.new_copy()`. """
return self._function_parameter_new_copy(self.orphans)
return self._function_parameter_new_copy(self._input_names, self.orphans)

def _function_parameter_new_copy(self, children):
def _function_parameter_new_copy(self, input_names, children):
"""Returns a new copy of the function parameter.
Inputs
------
input_names : : list
A list of str of the names of the children/function inputs
children : : list
A list of the children of the function
Expand All @@ -126,7 +168,12 @@ def _function_parameter_new_copy(self, children):
: :pybamm.FunctionParameter
A new copy of the function parameter
"""
return FunctionParameter(self.name, *children, diff_variable=self.diff_variable)

input_dict = {input_names[i]: children[i] for i in range(len(input_names))}

return FunctionParameter(
self.name, input_dict, diff_variable=self.diff_variable
)

def _evaluate_for_shape(self):
"""
Expand Down
85 changes: 75 additions & 10 deletions pybamm/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,13 +398,13 @@ def check_for_time_derivatives(self):
for node in eq.pre_order():
if isinstance(node, pybamm.VariableDot):
raise pybamm.ModelError(
"time derivative of variable found ({}) in rhs equation {}"
.format(node, key)
"time derivative of variable"
+ " found ({}) in rhs equation {}".format(node, key)
)
if isinstance(node, pybamm.StateVectorDot):
raise pybamm.ModelError(
"time derivative of state vector found ({}) in rhs equation {}"
.format(node, key)
"time derivative of state vector"
+ " found ({}) in rhs equation {}".format(node, key)
)

# Check that no variable time derivatives exist in the algebraic equations
Expand Down Expand Up @@ -441,8 +441,11 @@ def check_well_determined(self, post_discretisation):
[x.id for x in eqn.pre_order() if isinstance(x, pybamm.Variable)]
)
vars_in_eqns.update(
[x.get_variable().id for x in eqn.pre_order()
if isinstance(x, pybamm.VariableDot)]
[
x.get_variable().id
for x in eqn.pre_order()
if isinstance(x, pybamm.VariableDot)
]
)
for var, eqn in self.algebraic.items():
vars_in_algebraic_keys.update(
Expand All @@ -452,17 +455,23 @@ def check_well_determined(self, post_discretisation):
[x.id for x in eqn.pre_order() if isinstance(x, pybamm.Variable)]
)
vars_in_eqns.update(
[x.get_variable().id for x in eqn.pre_order()
if isinstance(x, pybamm.VariableDot)]
[
x.get_variable().id
for x in eqn.pre_order()
if isinstance(x, pybamm.VariableDot)
]
)
for var, side_eqn in self.boundary_conditions.items():
for side, (eqn, typ) in side_eqn.items():
vars_in_eqns.update(
[x.id for x in eqn.pre_order() if isinstance(x, pybamm.Variable)]
)
vars_in_eqns.update(
[x.get_variable().id for x in eqn.pre_order()
if isinstance(x, pybamm.VariableDot)]
[
x.get_variable().id
for x in eqn.pre_order()
if isinstance(x, pybamm.VariableDot)
]
)
# If any keys are repeated between rhs and algebraic then the model is
# overdetermined
Expand Down Expand Up @@ -614,6 +623,32 @@ def check_variables(self):
)
)

def info(self, symbol_name):
"""
Provides helpful summary information for a symbol.
Parameters
----------
parameter_name : str
"""

div = "-----------------------------------------"
symbol = find_symbol_in_model(self, symbol_name)

if not symbol:
return None

print(div)
print(symbol_name, "\n")
print(type(symbol))

if isinstance(symbol, pybamm.FunctionParameter):
print("")
print("Inputs:")
symbol.print_input_names()

print(div)

@property
def default_solver(self):
"Return default solver based on whether model is ODE model or DAE model"
Expand All @@ -624,3 +659,33 @@ def default_solver(self):
return pybamm.IDAKLUSolver()
else:
return pybamm.CasadiSolver(mode="safe")


# helper functions for finding symbols
def find_symbol_in_tree(tree, name):
if name == tree.name:
return tree
elif len(tree.children) > 0:
for child in tree.children:
child_return = find_symbol_in_tree(child, name)
if child_return:
return child_return


def find_symbol_in_dict(dic, name):
for tree in dic.values():
tree_return = find_symbol_in_tree(tree, name)
if tree_return:
return tree_return


def find_symbol_in_model(model, name):
dics = [
model.rhs,
model.algebraic,
model.variables,
]
for dic in dics:
dic_return = find_symbol_in_dict(dic, name)
if dic_return:
return dic_return
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(self, param):

def constant_voltage(variables):
V = variables["Terminal voltage [V]"]
return V - pybamm.FunctionParameter("Voltage function [V]", pybamm.t)
return V - pybamm.FunctionParameter("Voltage function [V]", {"Time [s]": pybamm.t})


class PowerFunctionControl(FunctionControl):
Expand All @@ -71,7 +71,9 @@ def __init__(self, param):
def constant_power(variables):
I = variables["Current [A]"]
V = variables["Terminal voltage [V]"]
return I * V - pybamm.FunctionParameter("Power function [W]", pybamm.t)
return I * V - pybamm.FunctionParameter(
"Power function [W]", {"Time [s]": pybamm.t}
)


class LeadingOrderFunctionControl(FunctionControl, LeadingOrderBaseModel):
Expand Down
10 changes: 8 additions & 2 deletions pybamm/models/submodels/particle/fickian_many_particles.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,18 @@ def get_coupled_variables(self, variables):

if self.domain == "Negative":
x = pybamm.standard_spatial_vars.x_n
R = pybamm.FunctionParameter("Negative particle distribution in x", x)
R = pybamm.FunctionParameter(
"Negative particle distribution in x",
{"Dimensionless through-cell position (x_n)": x},
)
variables.update({"Negative particle distribution in x": R})

elif self.domain == "Positive":
x = pybamm.standard_spatial_vars.x_p
R = pybamm.FunctionParameter("Positive particle distribution in x", x)
R = pybamm.FunctionParameter(
"Positive particle distribution in x",
{"Dimensionless through-cell position (x_p)": x},
)
variables.update({"Positive particle distribution in x": R})

return variables
Expand Down
2 changes: 1 addition & 1 deletion pybamm/parameters/electrical_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
# the user may provide the typical timescale as a parameter.
timescale = pybamm.Parameter("Typical timescale [s]")
dimensional_current_with_time = pybamm.FunctionParameter(
"Current function [A]", pybamm.t * timescale
"Current function [A]", {"Time[s]": pybamm.t * timescale}
)
dimensional_current_density_with_time = dimensional_current_with_time / (
n_electrodes_parallel * pybamm.geometric_parameters.A_cc
Expand Down
Loading

0 comments on commit 891424b

Please sign in to comment.