diff --git a/skfem/assembly/global_basis/interior_basis.py b/skfem/assembly/global_basis/interior_basis.py index 97a63ba27..b47be172c 100644 --- a/skfem/assembly/global_basis/interior_basis.py +++ b/skfem/assembly/global_basis/interior_basis.py @@ -1,6 +1,7 @@ -from typing import Optional - import numpy as np + +from typing import Optional, Callable + from numpy import ndarray from skfem.quadrature import get_quadrature @@ -122,3 +123,20 @@ def refinterp(self, interp: ndarray, Nrefs: Optional[int] = 1): M = meshclass(p, t, validate=False) return M, w.flatten() + + def interpolator(self, y: ndarray) -> Callable[[ndarray], ndarray]: + """Return a function handle, which can be used for finding + pointwise values of the given solution vector.""" + + finder = self.mesh.element_finder() + + def interpfun(x): + tris = finder(*x) + pts = self.mapping.invF(x[:, :, np.newaxis], tind=tris) + w = np.zeros(x.shape[1]) + for k in range(self.Nbfun): + phi = self.elem.gbasis(self.mapping, pts, k, tind=tris) + w += y[self.element_dofs[k, tris]]*phi[0].flatten() + return w + + return interpfun diff --git a/skfem/element.py b/skfem/element.py index 49b06b866..4f410efdc 100644 --- a/skfem/element.py +++ b/skfem/element.py @@ -15,7 +15,9 @@ import numpy as np -from typing import Tuple, Union, List +from typing import Optional, Tuple, Union, List + +from numpy import ndarray class Element(): nodal_dofs: int = 0 @@ -35,6 +37,39 @@ def orient(self, mapping, i, tind=None): else: return 1 + 0*tind + def gbasis(self, + mapping, + X: ndarray, + i: int, + tind: Optional[ndarray] = None) -> Union[Tuple[ndarray, ndarray], + Tuple[ndarray, ndarray, ndarray]]: + """Evaluate the global basis functions, given local points X. + + The global points - at which the global basis is evaluated at - + are defined through x = F(X), where F corresponds to the given mapping. + + Parameters + ---------- + mapping + Local-to-global mapping, an object of type :class:`~skfem.mapping.Mapping`. + X + An array of local points. The following shapes are supported: (Ndim x Npoints) + and (Ndim x Nelems x Npoints), i.e. local points shared by all elements + or different local points in each element. + i + Only the i'th basis function is evaluated. + tind + Optionally, choose a subset of elements to evaluate the basis at. + + Returns + ------- + (u, du) or (u, du, ddu) + The number of return arguments depends on the length of self.order. + The shape of k'th return argument depends on the value of self.order[k]. + + """ + raise NotImplementedError("Element must implement gbasis.") + class ElementH1(Element): order = (0, 1) @@ -83,7 +118,9 @@ class ElementHdiv(Element): order = (1, 0) def orient(self, mapping, i, tind=None): - # TODO fix tind + if tind is not None: + # TODO fix + raise NotImplementedError("TODO: fix tind support in ElementHdiv") return -1 + 2*(mapping.mesh.f2t[0, mapping.mesh.t2f[i, :]] \ == np.arange(mapping.mesh.t.shape[1])) @@ -102,10 +139,13 @@ def lbasis(self, X, i): class ElementHcurl(Element): """Note: only 3D support. Piola transformation is different in 2D.""" + order = (1, 1) def orient(self, mapping, i, tind=None): - # TODO fix tind + if tind is not None: + # TODO fix + raise NotImplementedError("TODO: fix tind support in ElementHcurl") t1 = [0, 1, 0, 0, 1, 2][i] t2 = [1, 2, 2, 3, 3, 3][i] return 1 - 2*(mapping.mesh.t[t1, :] > mapping.mesh.t[t2, :]) @@ -124,17 +164,23 @@ def lbasis(self, X, i): class ElementH2(Element): + """Elements defined implicitly through global degrees-of-freedom.""" + order = (0, 1, 2) V = None # For caching inverse Vandermonde matrix def gbasis(self, mapping, X, i, tind=None): + if tind is None: + tind = np.arange(mapping.mesh.t.shape[1]) # initialize power basis self._pbasis_init(self.maxdeg) N = len(self._pbasis) if self.V is None: # construct Vandermonde matrix and invert it - self.V = np.linalg.inv(self._eval_dofs(mapping.mesh, tind=tind)) + self.V = np.linalg.inv(self._eval_dofs(mapping.mesh)) + + V = self.V[tind] x = mapping.F(X, tind=tind) u = np.zeros(x[0].shape) @@ -143,17 +189,17 @@ def gbasis(self, mapping, X, i, tind=None): # loop over new basis for itr in range(N): - u += self.V[:, itr, i][:, None]\ + u += V[:, itr, i][:, None]\ * self._pbasis[itr](x[0], x[1]) - du[0] += self.V[:, itr, i][:, None]\ + du[0] += V[:, itr, i][:, None]\ * self._pbasisdx[itr](x[0], x[1]) - du[1] += self.V[:, itr, i][:,None]\ + du[1] += V[:, itr, i][:,None]\ * self._pbasisdy[itr](x[0], x[1]) - ddu[0, 0] += self.V[:, itr, i][:, None]\ + ddu[0, 0] += V[:, itr, i][:, None]\ * self._pbasisdxx[itr](x[0], x[1]) - ddu[0, 1] += self.V[:, itr, i][:, None]\ + ddu[0, 1] += V[:, itr, i][:, None]\ * self._pbasisdxy[itr](x[0], x[1]) - ddu[1, 1] += self.V[:, itr, i][:, None]\ + ddu[1, 1] += V[:, itr, i][:, None]\ * self._pbasisdyy[itr](x[0], x[1]) # dxy = dyx diff --git a/skfem/mesh/mesh.py b/skfem/mesh/mesh.py index e483bf6cd..863d24aa8 100644 --- a/skfem/mesh/mesh.py +++ b/skfem/mesh/mesh.py @@ -4,7 +4,9 @@ from .submesh import Submesh -from typing import Dict, Optional, Tuple, Type, TypeVar, Union +from typing import Dict, Optional, Tuple,\ + Type, TypeVar, Union,\ + Callable from numpy import ndarray MeshType = TypeVar('MeshType', bound='Mesh') @@ -295,3 +297,9 @@ def boundary_facets(self) -> ndarray: def interior_facets(self) -> ndarray: """Return an array of interior facet indices.""" return np.nonzero(self.f2t[1, :] >= 0)[0] + + def element_finder(self) -> Callable[[ndarray], ndarray]: + """Return a function, which returns element + indices corresponding to the input points.""" + raise NotImplementError("element_finder not implemented" +\ + "for the given Mesh type.") diff --git a/skfem/mesh/mesh2d/mesh_tri.py b/skfem/mesh/mesh2d/mesh_tri.py index 6689cd6f5..0661e8ce5 100644 --- a/skfem/mesh/mesh2d/mesh_tri.py +++ b/skfem/mesh/mesh2d/mesh_tri.py @@ -260,34 +260,6 @@ def _build_mappings(self, sort_t=True): # second row to zero if repeated (i.e., on boundary) self.f2t[1, np.nonzero(self.f2t[0, :] == self.f2t[1, :])[0]] = -1 - def interpolator(self, x): - """Return a function which interpolates values with P1 basis.""" - triang = mtri.Triangulation(self.p[0, :], self.p[1, :], self.t.T) - interpf = mtri.LinearTriInterpolator(triang, x) - # contruct an interpolator handle - def handle(X, Y): - return interpf(X, Y).data - return handle - - def const_interpolator(self, x): - """Return a function which interpolates values with P0 basis.""" - triang = mtri.Triangulation(self.p[0, :], self.p[1, :], self.t.T) - finder = triang.get_trifinder() - # construct an interpolator handle - def handle(X, Y): - return x[finder(X, Y)] - return handle - - def draw_debug(self): - """Draw without mesh.facets. For debugging self.draw().""" - fig = plt.figure() - plt.hold('on') - for itr in range(self.t.shape[1]): - plt.plot(self.p[0,self.t[[0,1],itr]], self.p[1,self.t[[0,1],itr]], 'k-') - plt.plot(self.p[0,self.t[[1,2],itr]], self.p[1,self.t[[1,2],itr]], 'k-') - plt.plot(self.p[0,self.t[[0,2],itr]], self.p[1,self.t[[0,2],itr]], 'k-') - return fig - def plot(self, z: ndarray, smooth: Optional[bool] = False, @@ -459,3 +431,10 @@ def _uniform_refine(self): def mapping(self): return MappingAffine(self) + + def element_finder(self): + from matplotlib.tri import Triangulation + + return Triangulation(self.p[0, :], + self.p[1, :], + self.t.T).get_trifinder()