Skip to content

Commit

Permalink
update centroid calculation to use center of mass
Browse files Browse the repository at this point in the history
  • Loading branch information
MuellerSeb committed Jul 11, 2019
1 parent e3428ed commit b0708a6
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 34 deletions.
7 changes: 6 additions & 1 deletion ogs5py/fileclasses/msh/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
rotation_matrix,
uncomment,
volume,
centroid,
)


Expand Down Expand Up @@ -1033,7 +1034,11 @@ def get_centroids(mesh):
if elem not in mesh_i["elements"]:
continue
points = mesh_i["nodes"][mesh_i["elements"][elem]]
out[elem] = np.mean(points, axis=1)
# node number needs to be first for "centroid()"
points = np.swapaxes(points, 0, 1)
out[elem] = centroid(elem, points)
# this was just the centroid of the element nodes
# out[elem] = np.mean(points, axis=1)
result.append(out)

if single:
Expand Down
181 changes: 148 additions & 33 deletions ogs5py/tools/tools.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
"""
tools for the ogs5py package
Tools for the ogs5py package.
.. currentmodule:: ogs5py.tools.tools
Expand Down Expand Up @@ -107,6 +107,7 @@ def flush(self):
sys.stdout.flush()

def __del__(self):
"""Close and delete."""
if not self._closed:
self.close()

Expand Down Expand Up @@ -135,6 +136,7 @@ def search_mkey(fin):
def uncomment(line):
"""
Remove OGS comments from a given line of an OGS file.
Comments are indicated by ";". The line is then splitted by whitespaces.
Parameters
Expand All @@ -147,7 +149,7 @@ def uncomment(line):

def is_key(sline):
"""
Check if the given splitted line is an OGS key
Check if the given splitted line is an OGS key.
Parameters
----------
Expand All @@ -159,7 +161,7 @@ def is_key(sline):

def is_mkey(sline):
"""
Check if the given splitted line is a main key
Check if the given splitted line is a main key.
Parameters
----------
Expand All @@ -171,7 +173,7 @@ def is_mkey(sline):

def is_skey(sline):
"""
Check if the given splitted line is a sub key
Check if the given splitted line is a sub key.
Parameters
----------
Expand All @@ -183,7 +185,7 @@ def is_skey(sline):

def get_key(sline):
"""
Get the key of a splitted line if there is any. Else return ""
Get the key of a splitted line if there is any, else return "".
Parameters
----------
Expand Down Expand Up @@ -233,7 +235,7 @@ def find_key_in_list(key, key_list):

def format_dict(dict_in):
"""
format the dictionary to use upper-case keys
Format the dictionary to use upper-case keys.
Parameters
----------
Expand All @@ -252,7 +254,7 @@ def format_dict(dict_in):

def guess_type(string):
"""
guess the type of a value given as string and return it accordingly
Guess the type of a value given as string and return it accordingly.
Parameters
----------
Expand All @@ -270,7 +272,7 @@ def guess_type(string):

def format_content_line(content):
"""
format a line of content to be a list of values
Format a line of content to be a list of values.
Parameters
----------
Expand All @@ -288,7 +290,7 @@ def format_content_line(content):

def format_content(content):
"""
format the content to be added to a 2D linewise array
Format the content to be added to a 2D linewise array.
Parameters
----------
Expand Down Expand Up @@ -329,7 +331,7 @@ def format_content(content):

def search_task_id(task_root, search_ext=None):
"""
Search for OGS model names in the given path
Search for OGS model names in the given path.
Parameters
----------
Expand Down Expand Up @@ -360,8 +362,9 @@ def search_task_id(task_root, search_ext=None):

def split_file_path(path, abs_path=False):
"""
decompose a path to a file into the dir-path, the basename
and the file-extension
Decompose a path to a file.
Decompose into the dir-path, the basename and the file-extension.
Parameters
----------
Expand All @@ -382,7 +385,7 @@ def split_file_path(path, abs_path=False):

def is_str_array(array):
"""
A routine to check if an array contains strings
A routine to check if an array contains strings.
Parameters
----------
Expand Down Expand Up @@ -464,6 +467,7 @@ def shift_points(points, vector):
def transform_points(points, xyz_func, **kwargs):
r"""
Transform points with a given function "xyz_func".
kwargs will be forwarded to "xyz_func".
Parameters
Expand Down Expand Up @@ -494,8 +498,10 @@ def hull_deform(
direction="z",
):
"""
Providing a transformation function to deform a given mesh in a given
direction by self defined hull-functions ``z = func(x, y)``.
Providing a transformation function to deform a given mesh.
Transformation is in a given
direction by a self defined hull-functions ``z = func(x, y)``.
Could be used with ``transform_mesh`` and ``transform_points``.
Parameters
Expand Down Expand Up @@ -523,7 +529,6 @@ def hull_deform(
x_out, y_out, z_out : ndarray
transformed arrays
"""

if direction == "x":
x1_in = y_in
x2_in = z_in
Expand All @@ -546,7 +551,7 @@ def hull_deform(
if isinstance(func_top, (float, int)):

def func_top_redef(x_in, __):
"""redefining func_top for constant value"""
"""Redefining func_top for constant value."""
return float(func_top) * np.ones_like(x_in)

func_t = func_top_redef
Expand All @@ -556,7 +561,7 @@ def func_top_redef(x_in, __):
if isinstance(func_bot, (float, int)):

def func_bot_redef(x_in, __):
"""redefining func_bot for constant value"""
"""Redefining func_bot for constant value."""
return float(func_bot) * np.ones_like(x_in)

func_b = func_bot_redef
Expand All @@ -570,10 +575,10 @@ def func_bot_redef(x_in, __):

if direction == "x":
return x3_out, x1_in, x2_in
if direction == "y":
elif direction == "y":
return x1_in, x3_out, x2_in
if direction == "z":
return x1_in, x2_in, x3_out
# elif direction == "z":
return x1_in, x2_in, x3_out


#####################
Expand All @@ -583,8 +588,9 @@ def func_bot_redef(x_in, __):

def rotation_matrix(vector, angle):
"""
Create a rotation matrix for rotation around a given vector with a given
angle.
Create a rotation matrix.
For rotation around a given vector with a given angle.
Parameters
----------
Expand All @@ -611,8 +617,9 @@ def rotation_matrix(vector, angle):

def replace(arr, inval, outval):
"""
replace certain values of 'arr' defined in 'inval' with values defined
in 'outval'
Replace values of 'arr'.
Replace values defined in 'inval' with values defined in 'outval'.
Parameters
----------
Expand Down Expand Up @@ -645,7 +652,7 @@ def replace(arr, inval, outval):

def unique_rows(data, decimals=4, fast=True):
"""
unique made row-data with respect to given precision
Unique made row-data with respect to given precision.
this is constructed to work best if point-pairs appear.
The output is sorted like the input data.
Expand Down Expand Up @@ -728,7 +735,7 @@ def unique_rows(data, decimals=4, fast=True):

def by_id(array, ids=None):
"""
Return a flattend array side-by-side with the array-element ids
Return a flattend array side-by-side with the array-element ids.
Parameters
----------
Expand All @@ -754,7 +761,8 @@ def by_id(array, ids=None):

def unique_rows_old(data, decimals=4):
"""
returns unique made data with respect to given precision in "decimals"
Returns unique made data with respect to given precision in "decimals".
The output is sorted like the input data.
data needs to be 2D
Expand Down Expand Up @@ -807,14 +815,14 @@ def unique_rows_old(data, decimals=4):
return out[sort], ixsort, ixrsort


####################
# volume functions #
####################
#################################
# volume and centroid functions #
#################################


def volume(typ, *pnt):
"""
Volume of a OGS5 Meshelement
Volume of a OGS5 Meshelement.
Parameters
----------
Expand Down Expand Up @@ -867,7 +875,7 @@ def volume(typ, *pnt):
if typ == "hex":
return _vol_hex(*pnt)

print("unknown volume typ: " + str(typ))
print("unknown element typ: " + str(typ))
return 0.0


Expand Down Expand Up @@ -908,3 +916,110 @@ def _vol_hex(*pnt):
return _vol_pris(
pnt[0], pnt[1], pnt[2], pnt[4], pnt[5], pnt[6]
) + _vol_pris(pnt[0], pnt[2], pnt[3], pnt[4], pnt[5], pnt[6], pnt[7])


def centroid(typ, *pnt):
"""
Centroid of a OGS5 Meshelement.
Parameters
----------
typ : string
OGS5 Meshelement type. Should be one of the following:
* "line" : 1D element with 2 nodes
* "tri" : 2D element with 3 nodes
* "quad" : 2D element with 4 nodes
* "tet" : 3D element with 4 nodes
* "pyra" : 3D element with 5 nodes
* "pris" : 3D element with 6 nodes
* "hex" : 3D element with 8 nodes
*pnt : Node Choordinates ``pnt = (x_0, x_1, ...)``
List of points defining the Meshelement. A point is given as an
(x,y,z) tuple and for each point, there can be a stack of points, if
the volume should be calculated for multiple elements of the same type.
Returns
-------
Volume : ndarray
Array containing the Centroids of the give elements.
"""
# if the pntinates are stacked, divide them
if len(pnt) == 1:
np_pnt = np.array(pnt[0], ndmin=3, dtype=float)
pnt = []
for i in range(np_pnt.shape[0]):
pnt.append(np_pnt[i])
# else assure we got numpy arrays as lists of points (x,y,z)
else:
pnt_list = list(pnt)
pnt = []
for i in range(len(pnt_list)):
pnt.append(np.array(pnt_list[i], ndmin=2, dtype=float))

if typ == "line":
return _cent_line(*pnt)
if typ == "tri":
return _cent_tri(*pnt)
if typ == "quad":
return _cent_quad(*pnt)
if typ == "tet":
return _cent_tet(*pnt)
if typ == "pyra":
return _cent_pyra(*pnt)
if typ == "pris":
return _cent_pris(*pnt)
if typ == "hex":
return _cent_hex(*pnt)

print("unknown element typ: " + str(typ))
return 0.0


def _cent_line(*pnt):
return (pnt[0] + pnt[1]) / 2.0


def _cent_tri(*pnt):
return (pnt[0] + pnt[1] + pnt[2]) / 3.0


def _cent_quad(*pnt):
return (
_cent_tri(pnt[0], pnt[1], pnt[2])
* _vol_tri(pnt[0], pnt[1], pnt[2])
+ _cent_tri(pnt[2], pnt[3], pnt[0])
* _vol_tri(pnt[2], pnt[3], pnt[0])
) / _vol_quad(*pnt)


def _cent_tet(*pnt):
return (pnt[0] + pnt[1] + pnt[2] + pnt[3]) / 4.0


def _cent_pyra(*pnt):
return (
_cent_tet(pnt[0], pnt[1], pnt[2], pnt[4])
* _vol_tet(pnt[0], pnt[1], pnt[2], pnt[4])
+ _cent_tet(pnt[0], pnt[2], pnt[3], pnt[4])
* _vol_tet(pnt[0], pnt[2], pnt[3], pnt[4])
) / _vol_pyra(*pnt)


def _cent_pris(*pnt):
return (
_cent_pyra(pnt[0], pnt[3], pnt[4], pnt[1], pnt[2])
* _vol_pyra(pnt[0], pnt[3], pnt[4], pnt[1], pnt[2])
+ _cent_tet(pnt[3], pnt[4], pnt[5], pnt[2])
* _vol_tet(pnt[3], pnt[4], pnt[5], pnt[2])
) / _vol_pris(*pnt)


def _cent_hex(*pnt):
return (
_cent_pris(pnt[0], pnt[1], pnt[2], pnt[4], pnt[5], pnt[6])
* _vol_pris(pnt[0], pnt[1], pnt[2], pnt[4], pnt[5], pnt[6])
+ _cent_pris(pnt[0], pnt[2], pnt[3], pnt[4], pnt[5], pnt[6], pnt[7])
* _vol_pris(pnt[0], pnt[2], pnt[3], pnt[4], pnt[5], pnt[6], pnt[7])
) / _vol_hex(*pnt)

0 comments on commit b0708a6

Please sign in to comment.