Skip to content

Commit

Permalink
Added InteriorBasis.interpolator. Removed MeshTri.{interpolator,const…
Browse files Browse the repository at this point in the history
…_interpolator,draw_debug}. Changed slicing and caching of Vandermonde matrix in ElementH2.gbasis.
  • Loading branch information
kinnala committed Oct 8, 2018
1 parent 5e52f0a commit 466049c
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 41 deletions.
22 changes: 20 additions & 2 deletions skfem/assembly/global_basis/interior_basis.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional

import numpy as np

from typing import Optional, Callable

from numpy import ndarray

from skfem.quadrature import get_quadrature
Expand Down Expand Up @@ -122,3 +123,20 @@ def refinterp(self, interp: ndarray, Nrefs: Optional[int] = 1):
M = meshclass(p, t, validate=False)

return M, w.flatten()

def interpolator(self, y: ndarray) -> Callable[[ndarray], ndarray]:
"""Return a function handle, which can be used for finding
pointwise values of the given solution vector."""

finder = self.mesh.element_finder()

def interpfun(x):
tris = finder(*x)
pts = self.mapping.invF(x[:, :, np.newaxis], tind=tris)
w = np.zeros(x.shape[1])
for k in range(self.Nbfun):
phi = self.elem.gbasis(self.mapping, pts, k, tind=tris)
w += y[self.element_dofs[k, tris]]*phi[0].flatten()
return w

return interpfun
66 changes: 56 additions & 10 deletions skfem/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

import numpy as np

from typing import Tuple, Union, List
from typing import Optional, Tuple, Union, List

from numpy import ndarray

class Element():
nodal_dofs: int = 0
Expand All @@ -35,6 +37,39 @@ def orient(self, mapping, i, tind=None):
else:
return 1 + 0*tind

def gbasis(self,
mapping,
X: ndarray,
i: int,
tind: Optional[ndarray] = None) -> Union[Tuple[ndarray, ndarray],
Tuple[ndarray, ndarray, ndarray]]:
"""Evaluate the global basis functions, given local points X.
The global points - at which the global basis is evaluated at -
are defined through x = F(X), where F corresponds to the given mapping.
Parameters
----------
mapping
Local-to-global mapping, an object of type :class:`~skfem.mapping.Mapping`.
X
An array of local points. The following shapes are supported: (Ndim x Npoints)
and (Ndim x Nelems x Npoints), i.e. local points shared by all elements
or different local points in each element.
i
Only the i'th basis function is evaluated.
tind
Optionally, choose a subset of elements to evaluate the basis at.
Returns
-------
(u, du) or (u, du, ddu)
The number of return arguments depends on the length of self.order.
The shape of k'th return argument depends on the value of self.order[k].
"""
raise NotImplementedError("Element must implement gbasis.")


class ElementH1(Element):
order = (0, 1)
Expand Down Expand Up @@ -83,7 +118,9 @@ class ElementHdiv(Element):
order = (1, 0)

def orient(self, mapping, i, tind=None):
# TODO fix tind
if tind is not None:
# TODO fix
raise NotImplementedError("TODO: fix tind support in ElementHdiv")
return -1 + 2*(mapping.mesh.f2t[0, mapping.mesh.t2f[i, :]] \
== np.arange(mapping.mesh.t.shape[1]))

Expand All @@ -102,10 +139,13 @@ def lbasis(self, X, i):
class ElementHcurl(Element):
"""Note: only 3D support. Piola transformation
is different in 2D."""

order = (1, 1)

def orient(self, mapping, i, tind=None):
# TODO fix tind
if tind is not None:
# TODO fix
raise NotImplementedError("TODO: fix tind support in ElementHcurl")
t1 = [0, 1, 0, 0, 1, 2][i]
t2 = [1, 2, 2, 3, 3, 3][i]
return 1 - 2*(mapping.mesh.t[t1, :] > mapping.mesh.t[t2, :])
Expand All @@ -124,17 +164,23 @@ def lbasis(self, X, i):


class ElementH2(Element):
"""Elements defined implicitly through global degrees-of-freedom."""

order = (0, 1, 2)
V = None # For caching inverse Vandermonde matrix

def gbasis(self, mapping, X, i, tind=None):
if tind is None:
tind = np.arange(mapping.mesh.t.shape[1])
# initialize power basis
self._pbasis_init(self.maxdeg)
N = len(self._pbasis)

if self.V is None:
# construct Vandermonde matrix and invert it
self.V = np.linalg.inv(self._eval_dofs(mapping.mesh, tind=tind))
self.V = np.linalg.inv(self._eval_dofs(mapping.mesh))

V = self.V[tind]

x = mapping.F(X, tind=tind)
u = np.zeros(x[0].shape)
Expand All @@ -143,17 +189,17 @@ def gbasis(self, mapping, X, i, tind=None):

# loop over new basis
for itr in range(N):
u += self.V[:, itr, i][:, None]\
u += V[:, itr, i][:, None]\
* self._pbasis[itr](x[0], x[1])
du[0] += self.V[:, itr, i][:, None]\
du[0] += V[:, itr, i][:, None]\
* self._pbasisdx[itr](x[0], x[1])
du[1] += self.V[:, itr, i][:,None]\
du[1] += V[:, itr, i][:,None]\
* self._pbasisdy[itr](x[0], x[1])
ddu[0, 0] += self.V[:, itr, i][:, None]\
ddu[0, 0] += V[:, itr, i][:, None]\
* self._pbasisdxx[itr](x[0], x[1])
ddu[0, 1] += self.V[:, itr, i][:, None]\
ddu[0, 1] += V[:, itr, i][:, None]\
* self._pbasisdxy[itr](x[0], x[1])
ddu[1, 1] += self.V[:, itr, i][:, None]\
ddu[1, 1] += V[:, itr, i][:, None]\
* self._pbasisdyy[itr](x[0], x[1])

# dxy = dyx
Expand Down
10 changes: 9 additions & 1 deletion skfem/mesh/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

from .submesh import Submesh

from typing import Dict, Optional, Tuple, Type, TypeVar, Union
from typing import Dict, Optional, Tuple,\
Type, TypeVar, Union,\
Callable
from numpy import ndarray

MeshType = TypeVar('MeshType', bound='Mesh')
Expand Down Expand Up @@ -295,3 +297,9 @@ def boundary_facets(self) -> ndarray:
def interior_facets(self) -> ndarray:
"""Return an array of interior facet indices."""
return np.nonzero(self.f2t[1, :] >= 0)[0]

def element_finder(self) -> Callable[[ndarray], ndarray]:
"""Return a function, which returns element
indices corresponding to the input points."""
raise NotImplementError("element_finder not implemented" +\
"for the given Mesh type.")
35 changes: 7 additions & 28 deletions skfem/mesh/mesh2d/mesh_tri.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,34 +260,6 @@ def _build_mappings(self, sort_t=True):
# second row to zero if repeated (i.e., on boundary)
self.f2t[1, np.nonzero(self.f2t[0, :] == self.f2t[1, :])[0]] = -1

def interpolator(self, x):
"""Return a function which interpolates values with P1 basis."""
triang = mtri.Triangulation(self.p[0, :], self.p[1, :], self.t.T)
interpf = mtri.LinearTriInterpolator(triang, x)
# contruct an interpolator handle
def handle(X, Y):
return interpf(X, Y).data
return handle

def const_interpolator(self, x):
"""Return a function which interpolates values with P0 basis."""
triang = mtri.Triangulation(self.p[0, :], self.p[1, :], self.t.T)
finder = triang.get_trifinder()
# construct an interpolator handle
def handle(X, Y):
return x[finder(X, Y)]
return handle

def draw_debug(self):
"""Draw without mesh.facets. For debugging self.draw()."""
fig = plt.figure()
plt.hold('on')
for itr in range(self.t.shape[1]):
plt.plot(self.p[0,self.t[[0,1],itr]], self.p[1,self.t[[0,1],itr]], 'k-')
plt.plot(self.p[0,self.t[[1,2],itr]], self.p[1,self.t[[1,2],itr]], 'k-')
plt.plot(self.p[0,self.t[[0,2],itr]], self.p[1,self.t[[0,2],itr]], 'k-')
return fig

def plot(self,
z: ndarray,
smooth: Optional[bool] = False,
Expand Down Expand Up @@ -459,3 +431,10 @@ def _uniform_refine(self):

def mapping(self):
return MappingAffine(self)

def element_finder(self):
from matplotlib.tri import Triangulation

return Triangulation(self.p[0, :],
self.p[1, :],
self.t.T).get_trifinder()

0 comments on commit 466049c

Please sign in to comment.