Skip to content

Commit

Permalink
⚡ 单独管理各分量,规避数组大小上限问题 (taichi-dev/taichi#6758)
Browse files Browse the repository at this point in the history
✨ 新的taichi函数封装与with-able FieldBuilder
  • Loading branch information
yanang007 committed Jan 1, 2023
1 parent c56ef8b commit 59ffb9e
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 34 deletions.
109 changes: 75 additions & 34 deletions metalpy/scab/demag/demagnetization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from discretize.utils import mkvc

from ..utils.misc import Field
from ...utils.taichi import ti_kernel, ti_ndarray
from ...utils.taichi import ti_kernel, ti_field, ti_FieldsBuilder


class Demagnetization:
Expand Down Expand Up @@ -62,31 +62,38 @@ def dpred(self, model):
nC = self.Xn.shape[0]
nObs = self.receiver_locations.shape[0]
H0 = self.source_field.unit_vector
H0 = np.repeat(H0, nC).ravel()
H0 = np.tile(H0[None, :], nC).ravel()

base_cell_sizes = np.r_[
self.mesh.h[0].min(),
self.mesh.h[1].min(),
self.mesh.h[2].min(),
]

A = ti_ndarray(ti.f64, (3 * nObs, 3 * nC))
# A = I - X @ T, where T is the forward kernel, T @ mv = Bv
# mv and Bv is channel first
# mv = [Mx1, My1, Mz1, ... Mxn, Myn, Mzn]
# Bv = [Bx1, By1, Bz1, ... Bxn, Byn, Bzn]
kernel_matrix_forward(self.receiver_locations, self.Xn, self.Yn, self.Zn, base_cell_sizes, model, A)
A = A.to_numpy()
with ti_FieldsBuilder() as builder:
Tmat = [
ti_field(ti.f64)
for _ in range(3 * 3)
] # Txx, Txy, Txz, Tyx, Tyy, Tyz, Tzx, Tzy, Tzz

X = np.repeat(model[None, :], 3, axis=0).ravel()
builder.dense(ti.ij, (nObs, nC)).place(*Tmat)
builder.finalize()

kernel_matrix_forward(self.receiver_locations, self.Xn, self.Yn, self.Zn, base_cell_sizes, model, *Tmat)

A = np.empty((3 * nObs, 3 * nC), dtype=np.float64)
for i in range(3):
for j in range(3):
tensor_to_ext_arr(Tmat[i * 3 + j], A, i, 3, j, 3)

X = np.tile(model, 3).ravel()
X = sp.diags(X)

b = X @ H0

m, info = pyamg.krylov.bicgstab(A, b)

# assert abs(X @ (H0 + T @ m) - m).mean() < 1e-3, 'fucked up'
return m.reshape(3, -1).T
return m.reshape(-1, 3)


@ti_kernel
Expand All @@ -97,15 +104,39 @@ def kernel_matrix_forward(
zn: ti.types.ndarray(),
base_cell_sizes: ti.types.ndarray(),
susc_model: ti.types.ndarray(),
ret: ti.types.ndarray()
Txx: ti.types.template(),
Txy: ti.types.template(),
Txz: ti.types.template(),
Tyx: ti.types.template(),
Tyy: ti.types.template(),
Tyz: ti.types.template(),
Tzx: ti.types.template(),
Tzy: ti.types.template(),
Tzz: ti.types.template(),
):
# calculates A = I - X @ T, where T is the forward kernel, s.t. T @ m_v = B_v
# m_v and B_v are both channel first (Array of Structure in taichi)
# m_v = [Mx1, My1, Mz1, ... Mxn, Myn, Mzn]
# B_v = [Bx1, By1, Bz1, ... Bxn, Byn, Bzn]
# | --- nC --- |
# ---------------------------------------------------------- ---
# | Txx, Txy, Txz, | Txx, Txy, Txz, | ... | Txx, Txy, Txz, | |
# | Tyx, Tyy, Tyz, | Tyx, Tyy, Tyz, | ... | Tyx, Tyy, Tyz, | |
# | Tzx, Tzy, Tzz, | Tzx, Tzy, Tzz, | ... | Tzx, Tzy, Tzz, | |
# ----------------------------------------------------------
# T = | ... | ... | ... | ... | nObs
# ----------------------------------------------------------
# | Txx, Txy, Txz, | Txx, Txy, Txz, | ... | Txx, Txy, Txz, | |
# | Tyx, Tyy, Tyz, | Tyx, Tyy, Tyz, | ... | Tyx, Tyy, Tyz, | |
# | Tzx, Tzy, Tzz, | Tzx, Tzy, Tzz, | ... | Tzx, Tzy, Tzz, | |
# ---------------------------------------------------------- ---

# TODO: This should probably be converted to C
tol1 = 1e-10 # Tolerance 1 for numerical stability over nodes and edges
tol2 = 1e-4 # Tolerance 2 for numerical stability over nodes and edges

# number of cells in mesh
nC = xn.shape[0]
nObs = receiver_locations.shape[0]

# base cell dimensions
min_hx = base_cell_sizes[0]
Expand Down Expand Up @@ -219,7 +250,7 @@ def kernel_matrix_forward(
arg39 = dy1 + r8
arg40 = dz1 + r8

bx_x = (
txx = (
-2 * ti.atan2(dx1, arg1 + tol1)
- -2 * ti.atan2(dx2, arg6 + tol1)
+ -2 * ti.atan2(dx2, arg11 + tol1)
Expand All @@ -230,7 +261,7 @@ def kernel_matrix_forward(
- -2 * ti.atan2(dx2, arg36 + tol1)
) / -4 / ti.math.pi

bx_y = (
txy = (
ti.log(arg5)
- ti.log(arg10)
+ ti.log(arg15)
Expand All @@ -241,14 +272,14 @@ def kernel_matrix_forward(
- ti.log(arg40)
) / -4 / ti.math.pi

bx_z = (
txz = (
ti.log(arg4) - ti.log(arg9)
+ ti.log(arg14) - ti.log(arg19)
+ ti.log(arg24) - ti.log(arg29)
+ ti.log(arg34) - ti.log(arg39)
) / -4 / ti.math.pi

by_x = (
tyx = (
ti.log(arg5)
- ti.log(arg10)
+ ti.log(arg15)
Expand All @@ -259,7 +290,7 @@ def kernel_matrix_forward(
- ti.log(arg40)
) / -4 / ti.math.pi

by_y = (
tyy = (
-2 * ti.atan2(dy2, arg2 + tol1)
- -2 * ti.atan2(dy2, arg7 + tol1)
+ -2 * ti.atan2(dy2, arg12 + tol1)
Expand All @@ -270,14 +301,14 @@ def kernel_matrix_forward(
- -2 * ti.atan2(dy1, arg37 + tol1)
) / -4 / ti.math.pi

by_z = (
tyz = (
ti.log(arg3) - ti.log(arg8)
+ ti.log(arg13) - ti.log(arg18)
+ ti.log(arg23) - ti.log(arg28)
+ ti.log(arg33) - ti.log(arg38)
) / -4 / ti.math.pi

bz_x = (
tzx = (
ti.log(arg4)
- ti.log(arg9)
+ ti.log(arg14)
Expand All @@ -288,14 +319,14 @@ def kernel_matrix_forward(
- ti.log(arg39)
) / -4 / ti.math.pi

bz_y = (
tzy = (
ti.log(arg3) - ti.log(arg8)
+ ti.log(arg13) - ti.log(arg18)
+ ti.log(arg23) - ti.log(arg28)
+ ti.log(arg33) - ti.log(arg38)
) / -4 / ti.math.pi

bz_z = (
tzz = (
-2 * ti.atan2(dz2, arg1_ + tol1)
- -2 * ti.atan2(dz2, arg6_ + tol1)
+ -2 * ti.atan2(dz1, arg11_ + tol1)
Expand All @@ -306,17 +337,27 @@ def kernel_matrix_forward(
- -2 * ti.atan2(dz1, arg36_ + tol1)
) / -4 / ti.math.pi

ret[iobs, icell] = -susc_model[icell] * bx_x
ret[iobs, icell + nC] = -susc_model[icell] * bx_y
ret[iobs, icell + 2 * nC] = -susc_model[icell] * bx_z
neg_sus = -susc_model[icell]

Txx[iobs, icell] = neg_sus * txx
Txy[iobs, icell] = neg_sus * txy
Txz[iobs, icell] = neg_sus * txz
Tyx[iobs, icell] = neg_sus * tyx
Tyy[iobs, icell] = neg_sus * tyy
Tyz[iobs, icell] = neg_sus * tyz
Tzx[iobs, icell] = neg_sus * tzx
Tzy[iobs, icell] = neg_sus * tzy
Tzz[iobs, icell] = neg_sus * tzz

ret[iobs + nObs, icell] = -susc_model[icell] * by_x
ret[iobs + nObs, icell + nC] = -susc_model[icell] * by_y
ret[iobs + nObs, icell + 2 * nC] = -susc_model[icell] * by_z
for i in range(nC):
Txx[i, i] += 1
Tyy[i, i] += 1
Tzz[i, i] += 1

ret[iobs + 2 * nObs, icell] = -susc_model[icell] * bz_x
ret[iobs + 2 * nObs, icell + nC] = -susc_model[icell] * bz_y
ret[iobs + 2 * nObs, icell + 2 * nC] = -susc_model[icell] * bz_z

for i in range(3 * nC):
ret[i, i] += 1
@ti_kernel
def tensor_to_ext_arr(tensor: ti.types.template(), arr: ti.types.ndarray(),
x0: ti.types.template(), xstride: ti.types.template(),
y0: ti.types.template(), ystride: ti.types.template()):
for I in ti.grouped(tensor):
arr[I[0] * xstride + x0, I[1] * ystride + y0] = tensor[I]
43 changes: 43 additions & 0 deletions metalpy/utils/taichi.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,46 @@ def lazy_evaluator_wrapper(*args, **kwargs):
def ti_ndarray(dtype, shape):
ti_init_once()
return ti.ndarray(dtype, shape)


def ti_root():
ti_init_once()
return ti.root


def ti_field(dtype,
shape=None,
order=None,
name="",
offset=None,
needs_grad=False,
needs_dual=False):
ti_init_once()
return ti.field(dtype, shape, order, name, offset, needs_grad, needs_dual)


class WrappedFieldsBuilder:
def __init__(self):
self.fields_builder = ti.FieldsBuilder()
self.snode_tree = None

def finalize(self, raise_warning=True):
self.snode_tree = self.fields_builder.finalize(raise_warning)
return self.snode_tree

def destroy(self):
self.snode_tree.destroy()

def __getattr__(self, name):
return getattr(self.fields_builder, name)

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.destroy()


def ti_FieldsBuilder():
ti_init_once()
return WrappedFieldsBuilder()

0 comments on commit 59ffb9e

Please sign in to comment.