Skip to content

Commit

Permalink
#632 fixing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Jun 12, 2020
1 parent 45de6f4 commit 3edd39a
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 23 deletions.
42 changes: 25 additions & 17 deletions pybamm/discretisations/discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pybamm
import numpy as np
from collections import defaultdict, OrderedDict
from scipy.sparse import block_diag, csc_matrix, csr_matrix
from scipy.sparse import block_diag, csc_matrix, csr_matrix, issparse
from scipy.sparse.linalg import inv


Expand Down Expand Up @@ -1106,16 +1106,14 @@ def check_initial_conditions_rhs(self, model):
y0 = model.concatenated_initial_conditions
# Individual
for var in model.rhs.keys():
assert (
model.rhs[var].shape == model.initial_conditions[var].shape
), pybamm.ModelError(
"""
rhs and initial_conditions must have the same shape after discretisation
but rhs.shape = {} and initial_conditions.shape = {} for variable '{}'.
""".format(
model.rhs[var].shape, model.initial_conditions[var].shape, var
if not model.rhs[var].shape == model.initial_conditions[var].shape:
raise pybamm.ModelError(
"rhs and initial_conditions must have the same shape after "
"discretisation but rhs.shape = "
"{} and initial_conditions.shape = {} for variable '{}'.".format(
model.rhs[var].shape, model.initial_conditions[var].shape, var
)
)
)
# Concatenated
assert (
model.concatenated_rhs.shape[0] + model.concatenated_algebraic.shape[0]
Expand Down Expand Up @@ -1150,17 +1148,27 @@ def check_variables(self, model):
not_concatenation = not isinstance(var, pybamm.Concatenation)

not_mult_by_one_vec = not (
isinstance(var, pybamm.Multiplication)
and isinstance(var.right, pybamm.Vector)
and np.all(var.right.entries == 1)
isinstance(
var, (pybamm.Multiplication, pybamm.MatrixMultiplication)
)
and (is_array_one(var.left) or is_array_one(var.right))
)

if different_shapes and not_concatenation and not_mult_by_one_vec:
raise pybamm.ModelError(
"""
variable and its eqn must have the same shape after discretisation
but variable.shape = {} and rhs.shape = {} for variable '{}'.
""".format(
"variable and its eqn must have the same shape after "
"discretisation but variable.shape = "
"{} and rhs.shape = {} for variable '{}'. ".format(
var.shape, model.rhs[rhs_var].shape, var
)
)


def is_array_one(symbol):
if not isinstance(symbol, pybamm.Array):
return False
entries = symbol.entries
if issparse(entries):
return np.all(entries.toarray() == 1)
else:
return np.all(entries == 1)
2 changes: 1 addition & 1 deletion pybamm/expression_tree/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,7 @@ def evaluates_to_number(self):
result = self.evaluate_ignoring_errors()

if isinstance(result, numbers.Number) or (
isinstance(result, np.ndarray) and result.shape == ()
isinstance(result, np.ndarray) and np.prod(result.shape) == 1
):
return True
else:
Expand Down
2 changes: 1 addition & 1 deletion pybamm/models/submodels/interface/diffusion_limited.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def _get_j_diffusion_limited_first_order(self, variables):
param = self.param
if self.domain == "Negative":
N_ox_s_p = variables["Oxygen flux"].orphans[1]
N_ox_neg_sep_interface = N_ox_s_p[0]
N_ox_neg_sep_interface = pybamm.Index(N_ox_s_p, slice(0, 1))

j = -N_ox_neg_sep_interface / param.C_e / -param.s_ox_Ox / param.l_n

Expand Down
8 changes: 4 additions & 4 deletions pybamm/spatial_methods/scikit_finite_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def integral(self, child, discretised_child, integration_dimension):

return out

def definite_integral_matrix(self, domains, vector_type="row"):
def definite_integral_matrix(self, child, vector_type="row"):
"""
Matrix for finite-element implementation of the definite integral over
the entire domain
Expand All @@ -315,8 +315,8 @@ def definite_integral_matrix(self, domains, vector_type="row"):
Parameters
----------
domains : dict
The domain(s) of integration
child : :class:`pybamm.Symbol`
The symbol being integrated
vector_type : str, optional
Whether to return a row or column vector (default is row)
Expand All @@ -326,7 +326,7 @@ def definite_integral_matrix(self, domains, vector_type="row"):
The finite element integral vector for the domain
"""
# get primary domain mesh
domain = domains["primary"]
domain = child.domains["primary"]
if isinstance(domain, list):
domain = domain[0]
mesh = self.mesh[domain]
Expand Down

0 comments on commit 3edd39a

Please sign in to comment.