Skip to content

Commit

Permalink
Merge pull request #1418 from WPengXiang/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
AlbertZyy authored Dec 26, 2024
2 parents f7f57cd + efd1d95 commit b678234
Show file tree
Hide file tree
Showing 12 changed files with 250 additions and 127 deletions.
49 changes: 17 additions & 32 deletions app/tssim/3d_NS/main_ipcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from fealpy.solver import spsolve

output = './'
T = 10
T = 1
nt = 500
n = 16

Expand All @@ -31,58 +31,43 @@
space = LagrangeFESpace(mesh, p=2)
uspace = TensorFunctionSpace(space, (2,-1))

solver = NSFEMSolver(pde, mesh, pspace, space, dt, q=5)
solver = NSFEMSolver(pde, mesh, pspace, uspace, dt, q=5)

ugdof = uspace.number_of_global_dofs()
pgdof = pspace.number_of_global_dofs()

u0 = uspace.function()
us = uspace.function()
u1 = uspace.function()
p0 = pspace.function()
p1 = pspace.function()
'''
ipoint = space.interpolation_points()
import matplotlib.pylab as plt
fig = plt.figure()
axes = fig.gca()
mesh.add_plot(axes)
#mesh.find_edge(axes,fontsize=20,showindex=True)
mesh.find_node(axes,node=ipoint,fontsize=20,showindex=True)
plt.show()
'''

fname = output + 'test_'+ str(0).zfill(10) + '.vtu'
mesh.nodedata['u'] = u1.reshape(2,-1).T
mesh.nodedata['p'] = p1
mesh.to_vtk(fname=fname)
'''

BC = DirichletBC(space=uspace,
gd=pde.velocity,
threshold=pde.is_u_boundary,
method='interp')
'''
BForm = solver.IPCS_BForm_0(None)
A = BForm.assembly()
print(A.to_dense())
#A = BC.apply_matrix(A)
print(bm.sum(bm.abs(A.to_dense())))



exit()
LForm = solver.Ossen_LForm()

BForm0 = solver.IPCS_BForm_0(None)
LForm0 = solver.IPCS_LForm_0()
A0 = BForm0.assembly()

for i in range(10):
for i in range(1):
t = timeline.next_time_level()
print(f"第{i+1}步")
print("time=", t)

solver.NS_update(u1)
A = BForm.assembly()
b = LForm.assembly()
A,b = BC.apply(A,b)
solver.update_ipcs_0(u0, p0)
print(bm.sum(bm.abs(A0.to_dense())))
b0 = LForm0.assembly()
A0,b0 = BC.apply(A0,b0)
print(bm.sum(bm.abs(b0)))

x = spsolve(A, b, 'mumps')
u1[:] = x[:ugdof]
p1[:] = x[ugdof:]
us[:] = spsolve(A0, b0, 'mumps')

fname = output + 'test_'+ str(i+1).zfill(10) + '.vtu'
mesh.nodedata['u'] = u1.reshape(2,-1).T
Expand Down
15 changes: 2 additions & 13 deletions app/tssim/NS-CH-GNBC/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
#bm.set_default_device('cuda')

output = './'
h = 1/256
#h = 1/256
h = 1/10
T = 2
nt = int(T/(0.1*h))

Expand All @@ -47,20 +48,8 @@
space = LagrangeFESpace(mesh, p=2)
uspace = TensorFunctionSpace(space, (2,-1))

'''
ipoint = space.interpolation_points()
import matplotlib.pylab as plt
fig = plt.figure()
axes = fig.gca()
mesh.add_plot(axes)
#mesh.find_edge(axes,fontsize=20,showindex=True)
mesh.find_node(axes,node=ipoint,fontsize=20,showindex=True)
plt.show()
'''

solver = Solver(pde, mesh, pspace, phispace, uspace, dt, q=5)


u0 = uspace.function()
u1 = uspace.function()
u2 = uspace.function()
Expand Down
15 changes: 8 additions & 7 deletions app/tssim/NS-CH-GNBC/pde.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,21 @@
from fealpy.mesh import TriangleMesh

class CouetteFlow:
'''
@brief: CouetteFlow
'''

def __init__(self, eps=1e-10, h=1/256):
self.eps = eps

## init the parameter
self.R = 5.0 ##dimensionless
self.l_s = 0.0025 ##dimensionless slip length
self.L_s = self.l_s / 100
self.L_s = self.l_s

self.epsilon = 0.004 ## the thickness of interface
self.L_d = 0.0005 ##phenomenological mobility cofficient
self.lam = 12.0 ##dimensionless
self.V_s = 200.0 ##dimensionless
self.s = 2.5 ##stablilizing parameter
#self.theta_s = bm.array(bm.pi/2)
self.theta_s = bm.array(77.6/180 * bm.pi)
self.theta_s = bm.array(bm.pi/2)
#self.theta_s = bm.array(77.6/180 * bm.pi)
self.h = h

def mesh(self):
Expand All @@ -56,6 +52,11 @@ def is_uy_Dirichlet(self,p):
return (bm.abs(p[..., 1] - 0.125) < self.eps) | \
(bm.abs(p[..., 1] + 0.125) < self.eps)

@cartesian
def is_slip_boundary(self,p):
return self.is_wall_boundary(p)
#return bm.logical_not(self.is_wall_boundary(p)))

@cartesian
def init_phi(self,p):
x = p[..., 0]
Expand Down
7 changes: 5 additions & 2 deletions app/tssim/NS-CH-GNBC/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, pde, mesh, pspace, phispace, uspace, dt, q=5):
self.pde = pde
self.dt = dt
self.q = q

def CH_BForm(self):
phispace = self.phispace
dt = self.dt
Expand All @@ -54,7 +54,8 @@ def CH_BForm(self):
A10 = BilinearForm(phispace)
A10.add_integrator(ScalarDiffusionIntegrator(coef=-epsilon, q=q))
A10.add_integrator(ScalarMassIntegrator(coef=-s/epsilon, q=q))
A10.add_integrator(BoundaryFaceMassIntegrator(coef=-3/(2*dt*V_s), q=q, threshold=self.pde.is_wall_boundary))
A10.add_integrator(BoundaryFaceMassIntegrator(coef=-3/(2*dt*V_s),
q=q, threshold=self.pde.is_slip_boundary))

A11 = BilinearForm(phispace)
A11.add_integrator(ScalarMassIntegrator(coef=1, q=q))
Expand Down Expand Up @@ -191,6 +192,7 @@ def NS_update(self, u_0, u_1, mu_2, phi_2, phi_1):
lam = self.pde.lam
epsilon = self.pde.epsilon
normal = self.mesh.edge_unit_normal()
normal[..., 1] = 1
tangent = self.mesh.edge_unit_tangent()
tangent[..., 0] = 1

Expand All @@ -214,6 +216,7 @@ def u_SI_coef(bcs, index):

def u_BF_SI_coef(bcs, index):
L_phi = epsilon*bm.einsum('eld, ed -> el', phi_2.grad_value(bcs, index), normal[index,:])
print(phi_2.grad_value(bcs, index))
L_phi -= 2*(bm.sqrt(bm.array(2))/6)*bm.pi*bm.cos(theta_s)*bm.cos((bm.pi/2)*phi_2(bcs, index))
L_phi += (bm.sqrt(bm.array(2))/6)*bm.pi*bm.cos(theta_s)*bm.cos((bm.pi/2)*phi_1(bcs, index))

Expand Down
43 changes: 39 additions & 4 deletions fealpy/cfd/ns_fem_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
SourceIntegrator,
PressWorkIntegrator,
FluidBoundaryFrictionIntegrator,
FluidBoundaryFrictionIntegratorP)
ViscousWorkIntegrator)
from fealpy.fem import (BoundaryFaceMassIntegrator,
BoundaryFaceSourceIntegrator)

Expand Down Expand Up @@ -94,8 +94,43 @@ def IPCS_BForm_0(self, threshold):
q = self.q

Bform = BilinearForm(uspace)
M = ScalarMassIntegrator(coef=R , q=q)
S = ScalarDiffusionIntegrator(coef=R , q=q)
F = FluidBoundaryFrictionIntegrator(coef=1, q=5, threshold=threshold)
M = ScalarMassIntegrator(coef=R/dt, q=q)
F = FluidBoundaryFrictionIntegrator(coef=-1, q=q, threshold=threshold)
VW = ViscousWorkIntegrator(coef=2, q=q)
Bform.add_integrator(VW)
Bform.add_integrator(F)
Bform.add_integrator(M)
return Bform

def IPCS_LForm_0(self, pthreshold=None):
pspace = self.pspace
uspace = self.uspace
dt = self.dt
R = self.pde.R
q = self.q

Lform = LinearForm(uspace)
self.ipcs_lform_SI = SourceIntegrator(q=q)
self.ipcs_lform_SI.keep_data()
self.ipcs_lform_BSI = BoundaryFaceSourceIntegrator(q=q, threshold=pthreshold)
self.ipcs_lform_BSI.keep_data()
return Lform

def update_ipcs_0(self, u0, p0):
dt = self.dt
R = self.pde.R

def coef(bcs, index):
result = 1/dt*u0(bcs, index)
result += np.einsum('...j, ....,ij -> ...i', u0(bcs, index), u0.grad_value(bcs, index))
result += np.repeat(p0(bcs,index)[...,np.newaxis], 2, axis=-1)
return

def B_coef(bcs, index):
result = np.einsum('..ij, ....j->...ij', p(bcs, index), self.mesh.edge_unit_normal(bcs, index))
return

self.ipcs_lform_SI.source = coef
self.ipcs_lform_BSI.source = B_coef


2 changes: 1 addition & 1 deletion fealpy/fem/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
from .curlcurl_integrator import CurlCurlIntegrator
from .nonlinear_elastic_integrator import NonlinearElasticIntegrator
from .div_integrator import DivIntegrator
from .viscous_work_integrator import ViscousWorkIntegrator
from .scalar_biharmonic_integrator import ScalarBiharmonicIntegrator


### Cell Source
from .cell_source_integrator import CellSourceIntegrator
SourceIntegrator = CellSourceIntegrator
Expand Down
59 changes: 1 addition & 58 deletions fealpy/fem/fluid_boundary_friction_integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,64 +81,7 @@ def assembly(self, space: _FS):
mesh = getattr(space, 'mesh', None)
bcs, ws, phi, gphi, fm, index, n = self.fetch(space)
val = process_coef_func(coef, bcs=bcs, mesh=mesh, etype='face', index=index)
gphin = bm.einsum('ej, eqlj->eql', n, gphi)
gphin = bm.einsum('e...i, eql...ij->eql...j', n, gphi)
result = bilinear_integral(phi, gphin, ws, fm, val, batched=self.batched)
return result

class FluidBoundaryFrictionIntegratorP(LinearInt, OpInt, FaceInt):
def __init__(self, coef: Optional[CoefLike]=None, q: Optional[int]=None, *,
threshold: Optional[Threshold]=None,
batched: bool=False, mesh):
super().__init__()
self.coef = coef
self.q = q
self.threshold = threshold
self.batched = batched
self.mesh = mesh

def make_index(self, space: _FS) -> TensorLike:
threshold = self.threshold
mesh = self.mesh
if isinstance(threshold, TensorLike):
index = threshold
else:
index = mesh.boundary_face_index()
if callable(threshold):
bc = mesh.entity_barycenter('face', index=index)
index = index[threshold(bc)]
return index

@enable_cache
def to_global_dof(self, space: _FS) -> TensorLike:
index = self.make_index(space)
return space.face_to_dof(index=index)

@enable_cache
def fetch(self, space: _FS):
space0 = space[0]
space1 = space[1]
index = self.make_index(space)
mesh = self.mesh

if not isinstance(mesh, HomogeneousMesh):
raise RuntimeError("The ScalarRobinBCIntegrator only support spaces on"
f"homogeneous meshes, but {type(mesh).__name__} is"
"not a subclass of HomoMesh.")

facemeasure = mesh.entity_measure('face', index=index)
q = space.p+3 if self.q is None else self.q
qf = mesh.quadrature_formula(q, 'face')
bcs, ws = qf.get_quadrature_points_and_weights()
phi1 = space[1].face_basis(bcs, index)
phi0 = space[0].face_basis(bcs, index)
n = mesh.edge_normal(index)
return bcs, ws, phi0, phi1, facemeasure, index, n

def assembly(self, space: _FS):
coef = self.coef
mesh = getattr(space, 'mesh', None)
bcs, ws, phi0, phi1, fm, index, n = self.fetch(space)
val = process_coef_func(coef, bcs=bcs, mesh=mesh, etype='face', index=index)
phi0n = bm.einsum('ej, e...j->e...j', n, phi0)
result = bilinear_integral(phi1, phi0n, ws, fm, val, batched=self.batched)
return result
33 changes: 25 additions & 8 deletions fealpy/fem/scalar_diffusion_integrator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Optional, Literal
from ..mesh.mesh_base import SimplexMesh

from ..backend import backend_manager as bm
from ..typing import TensorLike, Index, _S
Expand Down Expand Up @@ -62,7 +63,6 @@ def assembly(self, space: _FS, /, indices=None) -> TensorLike:
index = self.entity_selection(indices)
coef = process_coef_func(coef, bcs=bcs, mesh=mesh, etype='cell', index=index)
gphi = self.fetch_gphix(space, indices)

return bilinear_integral(gphi, gphi, ws, cm, coef, batched=self.batched)

@assemblymethod('fast')
Expand All @@ -72,13 +72,30 @@ def fast_assembly(self, space: _FS, /, indices=None) -> TensorLike:
TODO: 加入 assert
"""
mesh = space.mesh
_, ws = self.fetch_qf(space)
gphi = self.fetch_gphiu(space, indices)
M = bm.einsum('q, qik, qjl -> ijkl', ws, gphi, gphi)
cm = self.fetch_measure(space, indices)
glambda = mesh.grad_lambda(index=self.entity_selection(indices))
A = bm.einsum('ijkl, ckm, clm, c -> cij', M, glambda, glambda, cm)
return A
if isinstance(mesh, SimplexMesh):

_, ws = self.fetch_qf(space)
gphi = self.fetch_gphiu(space, indices)
M = bm.einsum('q, qik, qjl -> ijkl', ws, gphi, gphi)
cm = self.fetch_measure(space, indices)
glambda = mesh.grad_lambda(index=self.entity_selection(indices))
result = bm.einsum('ijkl, ckm, clm, c -> cij', M, glambda, glambda, cm)
else:
coef = self.coef
mesh = space.mesh
index = self.entity_selection(indices)
cm = self.fetch_measure(space, indices)
bcs,ws = self.fetch_qf(space)
coef = process_coef_func(coef, bcs=bcs, mesh=mesh, etype='cell', index=index)

gphiu = self.fetch_gphiu(space, indices)
M = bm.einsum('qim, qjn, q -> qijmn', gphiu, gphiu, ws)
J = mesh.jacobi_matrix(bcs, index)
G = mesh.first_fundamental_form(J)
G = bm.linalg.inv(G)
JG = bm.einsum("cqkm, cqmn -> cqkn", J, G)
result = bm.einsum('cqkn, qijmn, cqkm, c -> cij', JG, M, JG, cm) # (NC, NQ, ldof, GD)
return result

@assemblymethod('nonlinear')
def nonlinear_assembly(self, space: _FS, /, indices=None) -> TensorLike:
Expand Down
Loading

0 comments on commit b678234

Please sign in to comment.