Skip to content

Commit

Permalink
#705 fix disc of spatial var
Browse files Browse the repository at this point in the history
  • Loading branch information
rtimms committed Nov 6, 2019
1 parent e4cf50e commit 3a2c037
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 21 deletions.
6 changes: 4 additions & 2 deletions pybamm/spatial_methods/scikit_finite_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,13 @@ def spatial_variable(self, symbol):
symbol_mesh = self.mesh
if symbol.name == "y":
vector = pybamm.Vector(
symbol_mesh["current collector"][0].edges["y"], domain=symbol.domain
symbol_mesh["current collector"][0].coordinates[0, :][:, np.newaxis]
# symbol_mesh["current collector"][0].edges["y"], domain=symbol.domain
)
elif symbol.name == "z":
vector = pybamm.Vector(
symbol_mesh["current collector"][0].edges["z"], domain=symbol.domain
symbol_mesh["current collector"][0].coordinates[1, :][:, np.newaxis]
# symbol_mesh["current collector"][0].edges["z"], domain=symbol.domain
)
else:
raise pybamm.GeometryError(
Expand Down
28 changes: 10 additions & 18 deletions tests/unit/test_processed_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,39 +140,35 @@ def test_processed_variable_3D_x_z(self):

def test_processed_variable_3D_scikit(self):
var = pybamm.Variable("var", domain=["current collector"])
y = pybamm.SpatialVariable("y", domain=["current collector"])
z = pybamm.SpatialVariable("z", domain=["current collector"])

disc = tests.get_2p1d_discretisation_for_testing()
disc.set_variable_slices([var])
y_sol = disc.process_symbol(y).entries[:, 0]
z_sol = disc.process_symbol(z).entries[:, 0]
y = disc.mesh["current collector"][0].edges["y"]
z = disc.mesh["current collector"][0].edges["z"]
var_sol = disc.process_symbol(var)
t_sol = np.linspace(0, 1)
u_sol = np.ones(var_sol.shape[0])[:, np.newaxis] * np.linspace(0, 5)

processed_var = pybamm.ProcessedVariable(var_sol, t_sol, u_sol, mesh=disc.mesh)
np.testing.assert_array_equal(
processed_var.entries,
np.reshape(u_sol, [len(y_sol), len(z_sol), len(t_sol)]),
np.reshape(u_sol, [len(y), len(z), len(t_sol)]),
)

def test_processed_variable_2Dspace_scikit(self):
var = pybamm.Variable("var", domain=["current collector"])
y = pybamm.SpatialVariable("y", domain=["current collector"])
z = pybamm.SpatialVariable("z", domain=["current collector"])

disc = tests.get_2p1d_discretisation_for_testing()
disc.set_variable_slices([var])
y_sol = disc.process_symbol(y).entries[:, 0]
z_sol = disc.process_symbol(z).entries[:, 0]
y = disc.mesh["current collector"][0].edges["y"]
z = disc.mesh["current collector"][0].edges["z"]
var_sol = disc.process_symbol(var)
t_sol = np.array([0])
u_sol = np.ones(var_sol.shape[0])[:, np.newaxis]

processed_var = pybamm.ProcessedVariable(var_sol, t_sol, u_sol, mesh=disc.mesh)
np.testing.assert_array_equal(
processed_var.entries, np.reshape(u_sol, [len(y_sol), len(z_sol)])
processed_var.entries, np.reshape(u_sol, [len(y), len(z)])
)

def test_processed_var_1D_interpolation(self):
Expand Down Expand Up @@ -367,13 +363,11 @@ def test_processed_var_3D_r_first_dimension(self):

def test_processed_var_3D_scikit_interpolation(self):
var = pybamm.Variable("var", domain=["current collector"])
y = pybamm.SpatialVariable("y", domain=["current collector"])
z = pybamm.SpatialVariable("z", domain=["current collector"])

disc = tests.get_2p1d_discretisation_for_testing()
disc.set_variable_slices([var])
y_sol = disc.process_symbol(y).entries[:, 0]
z_sol = disc.process_symbol(z).entries[:, 0]
y_sol = disc.mesh["current collector"][0].edges["y"]
z_sol = disc.mesh["current collector"][0].edges["z"]
var_sol = disc.process_symbol(var)
t_sol = np.linspace(0, 1)
u_sol = np.ones(var_sol.shape[0])[:, np.newaxis] * np.linspace(0, 5)
Expand Down Expand Up @@ -406,13 +400,11 @@ def test_processed_var_3D_scikit_interpolation(self):

def test_processed_var_2Dspace_scikit_interpolation(self):
var = pybamm.Variable("var", domain=["current collector"])
y = pybamm.SpatialVariable("y", domain=["current collector"])
z = pybamm.SpatialVariable("z", domain=["current collector"])

disc = tests.get_2p1d_discretisation_for_testing()
disc.set_variable_slices([var])
y_sol = disc.process_symbol(y).entries[:, 0]
z_sol = disc.process_symbol(z).entries[:, 0]
y_sol = disc.mesh["current collector"][0].edges["y"]
z_sol = disc.mesh["current collector"][0].edges["z"]
var_sol = disc.process_symbol(var)
t_sol = np.array([0])
u_sol = np.ones(var_sol.shape[0])[:, np.newaxis]
Expand Down
27 changes: 26 additions & 1 deletion tests/unit/test_spatial_methods/test_scikit_finite_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,11 +493,36 @@ def test_dirichlet_bcs(self):
solver = pybamm.AlgebraicSolver()
solution = solver.solve(model)

# indepdent of y, so just check values for one y
# indepedent of y, so just check values for one y
z = mesh["current collector"][0].edges["z"][:, np.newaxis]
u_exact = a * z ** 2 + b * z + c
np.testing.assert_array_almost_equal(solution.y[0 : len(z)], u_exact)

def test_disc_spatial_var(self):
mesh = get_unit_2p1D_mesh_for_testing(ypts=4, zpts=5)
spatial_methods = {
"macroscale": pybamm.FiniteVolume,
"current collector": pybamm.ScikitFiniteElement,
}
disc = pybamm.Discretisation(mesh, spatial_methods)

# discretise y and z
y = pybamm.SpatialVariable("y", ["current collector"])
z = pybamm.SpatialVariable("z", ["current collector"])
y_disc = disc.process_symbol(y)
z_disc = disc.process_symbol(z)

# create expected meshgrid
y_vec = np.linspace(0, 1, 4)
z_vec = np.linspace(0, 1, 5)
Y, Z = np.meshgrid(y_vec, z_vec)
y_actual = np.transpose(Y).flatten()[:, np.newaxis]
z_actual = np.transpose(Z).flatten()[:, np.newaxis]

# spatial vars should discretise to the flattend meshgrid
np.testing.assert_array_equal(y_disc.evaluate(), y_actual)
np.testing.assert_array_equal(z_disc.evaluate(), z_actual)


if __name__ == "__main__":
print("Add -v for more debug output")
Expand Down

0 comments on commit 3a2c037

Please sign in to comment.