From 3a2c03700a31fc8ccdf6a002a52865f3e328cea6 Mon Sep 17 00:00:00 2001 From: Robert Timms Date: Wed, 6 Nov 2019 10:45:02 +0000 Subject: [PATCH] #705 fix disc of spatial var --- .../spatial_methods/scikit_finite_element.py | 6 ++-- tests/unit/test_processed_variable.py | 28 +++++++------------ .../test_scikit_finite_element.py | 27 +++++++++++++++++- 3 files changed, 40 insertions(+), 21 deletions(-) diff --git a/pybamm/spatial_methods/scikit_finite_element.py b/pybamm/spatial_methods/scikit_finite_element.py index 5d8331e489..f6e0832cec 100644 --- a/pybamm/spatial_methods/scikit_finite_element.py +++ b/pybamm/spatial_methods/scikit_finite_element.py @@ -51,11 +51,13 @@ def spatial_variable(self, symbol): symbol_mesh = self.mesh if symbol.name == "y": vector = pybamm.Vector( - symbol_mesh["current collector"][0].edges["y"], domain=symbol.domain + symbol_mesh["current collector"][0].coordinates[0, :][:, np.newaxis] + # symbol_mesh["current collector"][0].edges["y"], domain=symbol.domain ) elif symbol.name == "z": vector = pybamm.Vector( - symbol_mesh["current collector"][0].edges["z"], domain=symbol.domain + symbol_mesh["current collector"][0].coordinates[1, :][:, np.newaxis] + # symbol_mesh["current collector"][0].edges["z"], domain=symbol.domain ) else: raise pybamm.GeometryError( diff --git a/tests/unit/test_processed_variable.py b/tests/unit/test_processed_variable.py index a891fdbc6b..87c29a84a0 100644 --- a/tests/unit/test_processed_variable.py +++ b/tests/unit/test_processed_variable.py @@ -140,13 +140,11 @@ def test_processed_variable_3D_x_z(self): def test_processed_variable_3D_scikit(self): var = pybamm.Variable("var", domain=["current collector"]) - y = pybamm.SpatialVariable("y", domain=["current collector"]) - z = pybamm.SpatialVariable("z", domain=["current collector"]) disc = tests.get_2p1d_discretisation_for_testing() disc.set_variable_slices([var]) - y_sol = disc.process_symbol(y).entries[:, 0] - z_sol = disc.process_symbol(z).entries[:, 0] + y = disc.mesh["current collector"][0].edges["y"] + z = disc.mesh["current collector"][0].edges["z"] var_sol = disc.process_symbol(var) t_sol = np.linspace(0, 1) u_sol = np.ones(var_sol.shape[0])[:, np.newaxis] * np.linspace(0, 5) @@ -154,25 +152,23 @@ def test_processed_variable_3D_scikit(self): processed_var = pybamm.ProcessedVariable(var_sol, t_sol, u_sol, mesh=disc.mesh) np.testing.assert_array_equal( processed_var.entries, - np.reshape(u_sol, [len(y_sol), len(z_sol), len(t_sol)]), + np.reshape(u_sol, [len(y), len(z), len(t_sol)]), ) def test_processed_variable_2Dspace_scikit(self): var = pybamm.Variable("var", domain=["current collector"]) - y = pybamm.SpatialVariable("y", domain=["current collector"]) - z = pybamm.SpatialVariable("z", domain=["current collector"]) disc = tests.get_2p1d_discretisation_for_testing() disc.set_variable_slices([var]) - y_sol = disc.process_symbol(y).entries[:, 0] - z_sol = disc.process_symbol(z).entries[:, 0] + y = disc.mesh["current collector"][0].edges["y"] + z = disc.mesh["current collector"][0].edges["z"] var_sol = disc.process_symbol(var) t_sol = np.array([0]) u_sol = np.ones(var_sol.shape[0])[:, np.newaxis] processed_var = pybamm.ProcessedVariable(var_sol, t_sol, u_sol, mesh=disc.mesh) np.testing.assert_array_equal( - processed_var.entries, np.reshape(u_sol, [len(y_sol), len(z_sol)]) + processed_var.entries, np.reshape(u_sol, [len(y), len(z)]) ) def test_processed_var_1D_interpolation(self): @@ -367,13 +363,11 @@ def test_processed_var_3D_r_first_dimension(self): def test_processed_var_3D_scikit_interpolation(self): var = pybamm.Variable("var", domain=["current collector"]) - y = pybamm.SpatialVariable("y", domain=["current collector"]) - z = pybamm.SpatialVariable("z", domain=["current collector"]) disc = tests.get_2p1d_discretisation_for_testing() disc.set_variable_slices([var]) - y_sol = disc.process_symbol(y).entries[:, 0] - z_sol = disc.process_symbol(z).entries[:, 0] + y_sol = disc.mesh["current collector"][0].edges["y"] + z_sol = disc.mesh["current collector"][0].edges["z"] var_sol = disc.process_symbol(var) t_sol = np.linspace(0, 1) u_sol = np.ones(var_sol.shape[0])[:, np.newaxis] * np.linspace(0, 5) @@ -406,13 +400,11 @@ def test_processed_var_3D_scikit_interpolation(self): def test_processed_var_2Dspace_scikit_interpolation(self): var = pybamm.Variable("var", domain=["current collector"]) - y = pybamm.SpatialVariable("y", domain=["current collector"]) - z = pybamm.SpatialVariable("z", domain=["current collector"]) disc = tests.get_2p1d_discretisation_for_testing() disc.set_variable_slices([var]) - y_sol = disc.process_symbol(y).entries[:, 0] - z_sol = disc.process_symbol(z).entries[:, 0] + y_sol = disc.mesh["current collector"][0].edges["y"] + z_sol = disc.mesh["current collector"][0].edges["z"] var_sol = disc.process_symbol(var) t_sol = np.array([0]) u_sol = np.ones(var_sol.shape[0])[:, np.newaxis] diff --git a/tests/unit/test_spatial_methods/test_scikit_finite_element.py b/tests/unit/test_spatial_methods/test_scikit_finite_element.py index bf9a93a535..21bd4b8799 100644 --- a/tests/unit/test_spatial_methods/test_scikit_finite_element.py +++ b/tests/unit/test_spatial_methods/test_scikit_finite_element.py @@ -493,11 +493,36 @@ def test_dirichlet_bcs(self): solver = pybamm.AlgebraicSolver() solution = solver.solve(model) - # indepdent of y, so just check values for one y + # indepedent of y, so just check values for one y z = mesh["current collector"][0].edges["z"][:, np.newaxis] u_exact = a * z ** 2 + b * z + c np.testing.assert_array_almost_equal(solution.y[0 : len(z)], u_exact) + def test_disc_spatial_var(self): + mesh = get_unit_2p1D_mesh_for_testing(ypts=4, zpts=5) + spatial_methods = { + "macroscale": pybamm.FiniteVolume, + "current collector": pybamm.ScikitFiniteElement, + } + disc = pybamm.Discretisation(mesh, spatial_methods) + + # discretise y and z + y = pybamm.SpatialVariable("y", ["current collector"]) + z = pybamm.SpatialVariable("z", ["current collector"]) + y_disc = disc.process_symbol(y) + z_disc = disc.process_symbol(z) + + # create expected meshgrid + y_vec = np.linspace(0, 1, 4) + z_vec = np.linspace(0, 1, 5) + Y, Z = np.meshgrid(y_vec, z_vec) + y_actual = np.transpose(Y).flatten()[:, np.newaxis] + z_actual = np.transpose(Z).flatten()[:, np.newaxis] + + # spatial vars should discretise to the flattend meshgrid + np.testing.assert_array_equal(y_disc.evaluate(), y_actual) + np.testing.assert_array_equal(z_disc.evaluate(), z_actual) + if __name__ == "__main__": print("Add -v for more debug output")