From 41cb5d5aea03dfee001fcee36c18af1eeb20edb5 Mon Sep 17 00:00:00 2001 From: LSchueler Date: Tue, 9 Nov 2021 19:51:44 +0100 Subject: [PATCH] Restore Cython code and make Rust impl. optional Now, the GSTools-Core is used, if the package can be imported, but it can also be switched off, by setting the global var. `gstools.config.USE_RUST=False` during the runtime. --- MANIFEST.in | 2 +- README.md | 2 +- docs/source/index.rst | 2 +- gstools/__init__.py | 1 + gstools/config.py | 14 ++ gstools/field/generator.py | 11 +- gstools/field/summator.pyx | 82 ++++++++ gstools/krige/base.py | 16 +- gstools/krige/krigesum.pyx | 64 ++++++ gstools/variogram/estimator.pyx | 346 ++++++++++++++++++++++++++++++++ gstools/variogram/variogram.py | 25 ++- setup.cfg | 3 +- setup.py | 167 ++++++++++++++- 13 files changed, 712 insertions(+), 23 deletions(-) create mode 100644 gstools/config.py create mode 100644 gstools/field/summator.pyx create mode 100644 gstools/krige/krigesum.pyx create mode 100644 gstools/variogram/estimator.pyx diff --git a/MANIFEST.in b/MANIFEST.in index 93362961..71c3bb1d 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,6 +1,6 @@ prune * graft tests -recursive-include gstools *.py +recursive-include gstools *.py *.pyx recursive-exclude gstools *.c *.cpp include LICENSE README.md pyproject.toml setup.py setup.cfg exclude CHANGELOG.md CONTRIBUTING.md AUTHORS.md diff --git a/README.md b/README.md index 117ee2ea..c0488480 100644 --- a/README.md +++ b/README.md @@ -338,7 +338,6 @@ in memory for immediate 3D plotting in Python. - [NumPy >= 1.14.5](https://www.numpy.org) - [SciPy >= 1.1.0](https://www.scipy.org/scipylib) -- [GSTools-Core >= 0.1.0](https://github.com/GeoStat-Framework/GSTools-Core) - [hankel >= 1.0.2](https://github.com/steven-murray/hankel) - [emcee >= 3.0.0](https://github.com/dfm/emcee) - [pyevtk >= 1.1.1](https://github.com/pyscience-projects/pyevtk) @@ -346,6 +345,7 @@ in memory for immediate 3D plotting in Python. ### Optional +- [GSTools-Core >= 0.1.0](https://github.com/GeoStat-Framework/GSTools-Core) - [matplotlib](https://matplotlib.org) - [pyvista](https://docs.pyvista.org/) diff --git a/docs/source/index.rst b/docs/source/index.rst index 1aa563d5..b3696535 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -378,7 +378,6 @@ Requirements - `Numpy >= 1.14.5 `_ - `SciPy >= 1.1.0 `_ -- `GSTools-Core >= 0.1.0 `_ - `hankel >= 1.0.2 `_ - `emcee >= 3.0.0 `_ - `pyevtk >= 1.1.1 `_ @@ -388,6 +387,7 @@ Requirements Optional -------- +- `GSTools-Core >= 0.1.0 `_ - `matplotlib `_ - `pyvista `_ diff --git a/gstools/__init__.py b/gstools/__init__.py index 77281d8d..a82839cc 100644 --- a/gstools/__init__.py +++ b/gstools/__init__.py @@ -127,6 +127,7 @@ """ # Hooray! from gstools import ( + config, covmodel, field, krige, diff --git a/gstools/config.py b/gstools/config.py new file mode 100644 index 00000000..ebbd571c --- /dev/null +++ b/gstools/config.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +""" +GStools subpackage providing global variables. + +.. currentmodule:: gstools.config + +""" +# pylint: disable=W0611 +try: + import gstools_core + + USE_RUST = True +except ImportError: + USE_RUST = False diff --git a/gstools/field/generator.py b/gstools/field/generator.py index f8d7be31..91408e88 100644 --- a/gstools/field/generator.py +++ b/gstools/field/generator.py @@ -10,16 +10,23 @@ RandMeth IncomprRandMeth """ -# pylint: disable=C0103, W0222 +# pylint: disable=C0103, W0222, C0412 import warnings from copy import deepcopy as dcp import numpy as np -from gstools_core import summate, summate_incompr +from gstools import config from gstools.covmodel.base import CovModel from gstools.random.rng import RNG +if config.USE_RUST: + # pylint: disable=E0401 + from gstools_core import summate, summate_incompr +else: + # pylint: disable=C0412 + from gstools.field.summator import summate, summate_incompr + __all__ = ["RandMeth", "IncomprRandMeth"] diff --git a/gstools/field/summator.pyx b/gstools/field/summator.pyx new file mode 100644 index 00000000..ecd7ea58 --- /dev/null +++ b/gstools/field/summator.pyx @@ -0,0 +1,82 @@ +#cython: language_level=3, boundscheck=False, wraparound=False, cdivision=True +# -*- coding: utf-8 -*- +""" +This is the randomization method summator, implemented in cython. +""" + +import numpy as np + +cimport cython + +from cython.parallel import prange + +cimport numpy as np +from libc.math cimport cos, sin + + +def summate( + const double[:, :] cov_samples, + const double[:] z_1, + const double[:] z_2, + const double[:, :] pos + ): + cdef int i, j, d + cdef double phase + cdef int dim = pos.shape[0] + + cdef int X_len = pos.shape[1] + cdef int N = cov_samples.shape[1] + + cdef double[:] summed_modes = np.zeros(X_len, dtype=float) + + for i in prange(X_len, nogil=True): + for j in range(N): + phase = 0. + for d in range(dim): + phase += cov_samples[d, j] * pos[d, i] + summed_modes[i] += z_1[j] * cos(phase) + z_2[j] * sin(phase) + + return np.asarray(summed_modes) + + +cdef (double) abs_square(const double[:] vec) nogil: + cdef int i + cdef double r = 0. + + for i in range(vec.shape[0]): + r += vec[i]**2 + + return r + + +def summate_incompr( + const double[:, :] cov_samples, + const double[:] z_1, + const double[:] z_2, + const double[:, :] pos + ): + cdef int i, j, d + cdef double phase + cdef double k_2 + cdef int dim = pos.shape[0] + + cdef double[:] e1 = np.zeros(dim, dtype=float) + e1[0] = 1. + cdef double[:] proj = np.empty(dim) + + cdef int X_len = pos.shape[1] + cdef int N = cov_samples.shape[1] + + cdef double[:, :] summed_modes = np.zeros((dim, X_len), dtype=float) + + for i in range(X_len): + for j in range(N): + k_2 = abs_square(cov_samples[:, j]) + phase = 0. + for d in range(dim): + phase += cov_samples[d, j] * pos[d, i] + for d in range(dim): + proj[d] = e1[d] - cov_samples[d, j] * cov_samples[0, j] / k_2 + summed_modes[d, i] += proj[d] * (z_1[j] * cos(phase) + z_2[j] * sin(phase)) + + return np.asarray(summed_modes) diff --git a/gstools/krige/base.py b/gstools/krige/base.py index 5558ec7b..e3922d33 100755 --- a/gstools/krige/base.py +++ b/gstools/krige/base.py @@ -9,23 +9,29 @@ .. autosummary:: Krige """ -# pylint: disable=C0103, W0221, E1102, R0201 +# pylint: disable=C0103, W0221, E1102, R0201, C0412 import collections import numpy as np import scipy.linalg as spl from scipy.spatial.distance import cdist -from gstools_core import ( - calc_field_krige, - calc_field_krige_and_variance, -) +from gstools import config from gstools.field.base import Field from gstools.krige.tools import get_drift_functions, set_condition from gstools.tools.geometric import rotated_main_axes from gstools.tools.misc import eval_func from gstools.variogram import vario_estimate +if config.USE_RUST: + # pylint: disable=E0401 + from gstools_core import calc_field_krige, calc_field_krige_and_variance +else: + from gstools.krige.krigesum import ( + calc_field_krige, + calc_field_krige_and_variance, + ) + __all__ = ["Krige"] diff --git a/gstools/krige/krigesum.pyx b/gstools/krige/krigesum.pyx new file mode 100644 index 00000000..41911cef --- /dev/null +++ b/gstools/krige/krigesum.pyx @@ -0,0 +1,64 @@ +#cython: language_level=3, boundscheck=False, wraparound=False, cdivision=True +# -*- coding: utf-8 -*- +""" +This is a summator for the kriging routines +""" + +import numpy as np + +cimport cython +from cython.parallel import prange +cimport numpy as np + + +def calc_field_krige_and_variance( + const double[:, :] krig_mat, + const double[:, :] krig_vecs, + const double[:] cond +): + + cdef int mat_i = krig_mat.shape[0] + cdef int res_i = krig_vecs.shape[1] + + cdef double[:] field = np.zeros(res_i) + cdef double[:] error = np.zeros(res_i) + cdef double krig_fac + + cdef int i, j, k + + # error = krig_vecs * krig_mat * krig_vecs + # field = cond * krig_mat * krig_vecs + for k in prange(res_i, nogil=True): + for i in range(mat_i): + krig_fac = 0.0 + for j in range(mat_i): + krig_fac += krig_mat[i, j] * krig_vecs[j, k] + error[k] += krig_vecs[i, k] * krig_fac + field[k] += cond[i] * krig_fac + + return np.asarray(field), np.asarray(error) + + +def calc_field_krige( + const double[:, :] krig_mat, + const double[:, :] krig_vecs, + const double[:] cond +): + + cdef int mat_i = krig_mat.shape[0] + cdef int res_i = krig_vecs.shape[1] + + cdef double[:] field = np.zeros(res_i) + cdef double krig_fac + + cdef int i, j, k + + # field = cond * krig_mat * krig_vecs + for k in prange(res_i, nogil=True): + for i in range(mat_i): + krig_fac = 0.0 + for j in range(mat_i): + krig_fac += krig_mat[i, j] * krig_vecs[j, k] + field[k] += cond[i] * krig_fac + + return np.asarray(field) diff --git a/gstools/variogram/estimator.pyx b/gstools/variogram/estimator.pyx new file mode 100644 index 00000000..8c149b41 --- /dev/null +++ b/gstools/variogram/estimator.pyx @@ -0,0 +1,346 @@ +#cython: language_level=3, boundscheck=False, wraparound=False, cdivision=True +# distutils: language = c++ +# -*- coding: utf-8 -*- +""" +This is the variogram estimater, implemented in cython. +""" + +import numpy as np + +cimport cython + +from cython.parallel import parallel, prange + +cimport numpy as np +from libc.math cimport M_PI, acos, atan2, cos, fabs, isnan, pow, sin, sqrt + + +cdef inline double dist_euclid( + const int dim, + const double[:,:] pos, + const int i, + const int j, +) nogil: + cdef int d + cdef double dist_squared = 0.0 + for d in range(dim): + dist_squared += ((pos[d,i] - pos[d,j]) * (pos[d,i] - pos[d,j])) + return sqrt(dist_squared) + + +cdef inline double dist_haversine( + const int dim, + const double[:,:] pos, + const int i, + const int j, +) nogil: + # pos holds lat-lon in deg + cdef double deg_2_rad = M_PI / 180.0 + cdef double diff_lat = (pos[0, j] - pos[0, i]) * deg_2_rad + cdef double diff_lon = (pos[1, j] - pos[1, i]) * deg_2_rad + cdef double arg = ( + pow(sin(diff_lat/2.0), 2) + + cos(pos[0, i]*deg_2_rad) * + cos(pos[0, j]*deg_2_rad) * + pow(sin(diff_lon/2.0), 2) + ) + return 2.0 * atan2(sqrt(arg), sqrt(1.0-arg)) + + +ctypedef double (*_dist_func)( + const int, + const double[:,:], + const int, + const int, +) nogil + + +cdef inline bint dir_test( + const int dim, + const double[:,:] pos, + const double dist, + const double[:,:] direction, + const double angles_tol, + const double bandwidth, + const int i, + const int j, + const int d, +) nogil: + cdef double s_prod = 0.0 # scalar product + cdef double b_dist = 0.0 # band-distance + cdef double tmp # temporary variable + cdef int k + cdef bint in_band = True + cdef bint in_angle = True + + # scalar-product calculation for bandwidth projection and angle calculation + for k in range(dim): + s_prod += (pos[k,i] - pos[k,j]) * direction[d,k] + + # calculate band-distance by projection of point-pair-vec to direction line + if bandwidth > 0.0: + for k in range(dim): + tmp = (pos[k,i] - pos[k,j]) - s_prod * direction[d,k] + b_dist += tmp * tmp + in_band = sqrt(b_dist) < bandwidth + + # allow repeating points (dist = 0) + if dist > 0.0: + # use smallest angle by taking absolute value for arccos angle formula + tmp = fabs(s_prod) / dist + if tmp < 1.0: # else same direction (prevent numerical errors) + in_angle = acos(tmp) < angles_tol + + return in_band and in_angle + + +cdef inline double estimator_matheron(const double f_diff) nogil: + return f_diff * f_diff + +cdef inline double estimator_cressie(const double f_diff) nogil: + return sqrt(fabs(f_diff)) + +ctypedef double (*_estimator_func)(const double) nogil + +cdef inline void normalization_matheron( + double[:] variogram, + long[:] counts, +): + cdef int i + for i in range(variogram.shape[0]): + # avoid division by zero + variogram[i] /= (2. * max(counts[i], 1)) + +cdef inline void normalization_cressie( + double[:] variogram, + long[:] counts, +): + cdef int i + cdef long cnt + for i in range(variogram.shape[0]): + # avoid division by zero + cnt = max(counts[i], 1) + variogram[i] = ( + 0.5 * (1./cnt * variogram[i])**4 / + (0.457 + 0.494 / cnt + 0.045 / cnt**2) + ) + +ctypedef void (*_normalization_func)( + double[:], + long[:], +) + +cdef inline void normalization_matheron_vec( + double[:,:] variogram, + long[:,:] counts, +): + cdef int d, i + for d in range(variogram.shape[0]): + normalization_matheron(variogram[d, :], counts[d, :]) + +cdef inline void normalization_cressie_vec( + double[:,:] variogram, + long[:,:] counts, +): + cdef int d, i + cdef long cnt + for d in range(variogram.shape[0]): + normalization_cressie(variogram[d, :], counts[d, :]) + +ctypedef void (*_normalization_func_vec)( + double[:,:], + long[:,:], +) + +cdef _estimator_func choose_estimator_func(str estimator_type): + cdef _estimator_func estimator_func + if estimator_type == 'm': + estimator_func = estimator_matheron + elif estimator_type == 'c': + estimator_func = estimator_cressie + return estimator_func + +cdef _normalization_func choose_estimator_normalization(str estimator_type): + cdef _normalization_func normalization_func + if estimator_type == 'm': + normalization_func = normalization_matheron + elif estimator_type == 'c': + normalization_func = normalization_cressie + return normalization_func + +cdef _normalization_func_vec choose_estimator_normalization_vec(str estimator_type): + cdef _normalization_func_vec normalization_func_vec + if estimator_type == 'm': + normalization_func_vec = normalization_matheron_vec + elif estimator_type == 'c': + normalization_func_vec = normalization_cressie_vec + return normalization_func_vec + + +def directional( + const int dim, + const double[:,:] f, + const double[:] bin_edges, + const double[:,:] pos, + const double[:,:] direction, # should be normed + const double angles_tol=M_PI/8.0, + const double bandwidth=-1.0, # negative values to turn of bandwidth search + const bint separate_dirs=False, # whether the direction bands don't overlap + str estimator_type='m', +): + if pos.shape[1] != f.shape[1]: + raise ValueError('len(pos) = {0} != len(f) = {1} '. + format(pos.shape[1], f.shape[1])) + + if bin_edges.shape[0] < 2: + raise ValueError('len(bin_edges) too small') + + if angles_tol <= 0: + raise ValueError('tolerance for angle search masks must be > 0') + + cdef _estimator_func estimator_func = choose_estimator_func(estimator_type) + cdef _normalization_func_vec normalization_func_vec = ( + choose_estimator_normalization_vec(estimator_type) + ) + + cdef int d_max = direction.shape[0] + cdef int i_max = bin_edges.shape[0] - 1 + cdef int j_max = pos.shape[1] - 1 + cdef int k_max = pos.shape[1] + cdef int f_max = f.shape[0] + + cdef double[:,:] variogram = np.zeros((d_max, len(bin_edges)-1)) + cdef long[:,:] counts = np.zeros((d_max, len(bin_edges)-1), dtype=long) + cdef int i, j, k, m, d + cdef double dist + + for i in prange(i_max, nogil=True): + for j in range(j_max): + for k in range(j+1, k_max): + dist = dist_euclid(dim, pos, j, k) + if dist < bin_edges[i] or dist >= bin_edges[i+1]: + continue # skip if not in current bin + for d in range(d_max): + if not dir_test(dim, pos, dist, direction, angles_tol, bandwidth, k, j, d): + continue # skip if not in current direction + for m in range(f_max): + # skip no data values + if not (isnan(f[m,k]) or isnan(f[m,j])): + counts[d, i] += 1 + variogram[d, i] += estimator_func(f[m,k] - f[m,j]) + # once we found a fitting direction + # break the search if directions are separated + if separate_dirs: + break + + normalization_func_vec(variogram, counts) + return np.asarray(variogram), np.asarray(counts) + +def unstructured( + const int dim, + const double[:,:] f, + const double[:] bin_edges, + const double[:,:] pos, + str estimator_type='m', + str distance_type='e', +): + cdef _dist_func distance + + if distance_type == 'e': + distance = dist_euclid + else: + distance = dist_haversine + if dim != 2: + raise ValueError('Haversine: dim = {0} != 2'.format(dim)) + + if pos.shape[1] != f.shape[1]: + raise ValueError('len(pos) = {0} != len(f) = {1} '. + format(pos.shape[1], f.shape[1])) + + if bin_edges.shape[0] < 2: + raise ValueError('len(bin_edges) too small') + + cdef _estimator_func estimator_func = choose_estimator_func(estimator_type) + cdef _normalization_func normalization_func = ( + choose_estimator_normalization(estimator_type) + ) + + cdef int i_max = bin_edges.shape[0] - 1 + cdef int j_max = pos.shape[1] - 1 + cdef int k_max = pos.shape[1] + cdef int f_max = f.shape[0] + + cdef double[:] variogram = np.zeros(len(bin_edges)-1) + cdef long[:] counts = np.zeros(len(bin_edges)-1, dtype=long) + cdef int i, j, k, m + cdef double dist + + for i in prange(i_max, nogil=True): + for j in range(j_max): + for k in range(j+1, k_max): + dist = distance(dim, pos, j, k) + if dist < bin_edges[i] or dist >= bin_edges[i+1]: + continue # skip if not in current bin + for m in range(f_max): + # skip no data values + if not (isnan(f[m,k]) or isnan(f[m,j])): + counts[i] += 1 + variogram[i] += estimator_func(f[m,k] - f[m,j]) + + normalization_func(variogram, counts) + return np.asarray(variogram), np.asarray(counts) + + +def structured(const double[:,:] f, str estimator_type='m'): + cdef _estimator_func estimator_func = choose_estimator_func(estimator_type) + cdef _normalization_func normalization_func = ( + choose_estimator_normalization(estimator_type) + ) + + cdef int i_max = f.shape[0] - 1 + cdef int j_max = f.shape[1] + cdef int k_max = i_max + 1 + + cdef double[:] variogram = np.zeros(k_max) + cdef long[:] counts = np.zeros(k_max, dtype=long) + cdef int i, j, k + + with nogil, parallel(): + for i in range(i_max): + for j in range(j_max): + for k in prange(1, k_max-i): + counts[k] += 1 + variogram[k] += estimator_func(f[i,j] - f[i+k,j]) + + normalization_func(variogram, counts) + return np.asarray(variogram) + + +def ma_structured( + const double[:,:] f, + const bint[:,:] mask, + str estimator_type='m', +): + cdef _estimator_func estimator_func = choose_estimator_func(estimator_type) + cdef _normalization_func normalization_func = ( + choose_estimator_normalization(estimator_type) + ) + + cdef int i_max = f.shape[0] - 1 + cdef int j_max = f.shape[1] + cdef int k_max = i_max + 1 + + cdef double[:] variogram = np.zeros(k_max) + cdef long[:] counts = np.zeros(k_max, dtype=long) + cdef int i, j, k + + with nogil, parallel(): + for i in range(i_max): + for j in range(j_max): + for k in prange(1, k_max-i): + if not mask[i,j] and not mask[i+k,j]: + counts[k] += 1 + variogram[k] += estimator_func(f[i,j] - f[i+k,j]) + + normalization_func(variogram, counts) + return np.asarray(variogram) diff --git a/gstools/variogram/variogram.py b/gstools/variogram/variogram.py index d758e3b4..e24255eb 100644 --- a/gstools/variogram/variogram.py +++ b/gstools/variogram/variogram.py @@ -10,14 +10,10 @@ vario_estimate vario_estimate_axis """ +# pylint: disable=C0412 import numpy as np -from gstools_core import ( - variogram_directional as directional, - variogram_ma_structured as ma_structured, - variogram_structured as structured, - variogram_unstructured as unstructured, -) +from gstools import config from gstools.normalizer.tools import remove_trend_norm_mean from gstools.tools.geometric import ( ang2dir, @@ -27,6 +23,21 @@ ) from gstools.variogram.binning import standard_bins +if config.USE_RUST: + # pylint: disable=E0401 + from gstools_core import variogram_directional as directional + from gstools_core import variogram_ma_structured as ma_structured + from gstools_core import variogram_structured as structured + from gstools_core import variogram_unstructured as unstructured +else: + # pylint: disable=C0412 + from gstools.variogram.estimator import ( + directional, + ma_structured, + structured, + unstructured, + ) + __all__ = [ "vario_estimate", "vario_estimate_axis", @@ -443,6 +454,8 @@ def vario_estimate_axis( if missing: field.mask = np.logical_or(field.mask, missing_mask) mask = np.ma.getmaskarray(field) + if not config.USE_RUST: + mask = np.asarray(mask, dtype=np.int32) else: field = np.array(field, ndmin=1, dtype=np.double, copy=False) missing_mask = None # free space diff --git a/setup.cfg b/setup.cfg index d533120b..116b51ba 100644 --- a/setup.cfg +++ b/setup.cfg @@ -39,7 +39,6 @@ project_urls = packages = find: install_requires = emcee>=3.0.0,<4 - gstools_core>=0.1.2,<1 hankel>=1.0.2,<2 meshio>=4.0.3,<6 numpy>=1.14.5,<2 @@ -67,6 +66,8 @@ doc = plotting = matplotlib>=3,<4 pyvista>=0.29,<1 +rust = + gstools_core>=0.1.2,<1 test = coverage[toml]>=5.2.1,<6 pytest>=6.0,<7 diff --git a/setup.py b/setup.py index 291d34a5..1bbc7faf 100644 --- a/setup.py +++ b/setup.py @@ -1,14 +1,169 @@ # -*- coding: utf-8 -*- """GSTools: A geostatistical toolbox.""" +import glob import os - -import numpy as np -from setuptools import setup +import subprocess import sys -import site +import tempfile +from distutils.ccompiler import new_compiler +from distutils.errors import CompileError, LinkError +from distutils.sysconfig import customize_compiler -site.ENABLE_USER_SITE = "--user" in sys.argv[1:] +import numpy as np +from Cython.Build import cythonize +from setuptools import Extension, setup HERE = os.path.abspath(os.path.dirname(__file__)) -setup(include_dirs=[np.get_include()]) + +# openmp finder ############################################################### +# This code is adapted for a large part from the scikit-learn openmp_helpers.py +# which can be found at: +# https://github.com/scikit-learn/scikit-learn/blob/0.24.0/sklearn/_build_utils + + +CCODE = """ +#include +#include +int main(void) { +#pragma omp parallel +printf("nthreads=%d\\n", omp_get_num_threads()); +return 0; +} +""" + + +def get_openmp_flag(compiler): + """Get the compiler dependent openmp flag.""" + if hasattr(compiler, "compiler"): + compiler = compiler.compiler[0] + else: + compiler = compiler.__class__.__name__ + + if sys.platform == "win32" and ("icc" in compiler or "icl" in compiler): + return ["/Qopenmp"] + if sys.platform == "win32": + return ["/openmp"] + if sys.platform == "darwin" and ("icc" in compiler or "icl" in compiler): + return ["-openmp"] + if sys.platform == "darwin" and "openmp" in os.getenv("CPPFLAGS", ""): + return [] + # Default flag for GCC and clang: + return ["-fopenmp"] + + +def check_openmp_support(): + """Check whether OpenMP test code can be compiled and run.""" + ccompiler = new_compiler() + customize_compiler(ccompiler) + + with tempfile.TemporaryDirectory() as tmp_dir: + try: + os.chdir(tmp_dir) + # Write test program + with open("test_openmp.c", "w") as cfile: + cfile.write(CCODE) + os.mkdir("objects") + # Compile, test program + openmp_flags = get_openmp_flag(ccompiler) + ccompiler.compile( + ["test_openmp.c"], + output_dir="objects", + extra_postargs=openmp_flags, + ) + # Link test program + extra_preargs = os.getenv("LDFLAGS", None) + if extra_preargs is not None: + extra_preargs = extra_preargs.split(" ") + else: + extra_preargs = [] + objects = glob.glob( + os.path.join("objects", "*" + ccompiler.obj_extension) + ) + ccompiler.link_executable( + objects, + "test_openmp", + extra_preargs=extra_preargs, + extra_postargs=openmp_flags, + ) + # Run test program + output = subprocess.check_output("./test_openmp") + output = output.decode(sys.stdout.encoding or "utf-8").splitlines() + # Check test program output + if "nthreads=" in output[0]: + nthreads = int(output[0].strip().split("=")[1]) + openmp_supported = len(output) == nthreads + else: + openmp_supported = False + openmp_flags = [] + except (CompileError, LinkError, subprocess.CalledProcessError): + openmp_supported = False + openmp_flags = [] + finally: + os.chdir(HERE) + return openmp_supported, openmp_flags + + +# openmp ###################################################################### + + +# you can set GSTOOLS_BUILD_PARALLEL=0 or GSTOOLS_BUILD_PARALLEL=1 +GS_PARALLEL = os.getenv("GSTOOLS_BUILD_PARALLEL") +USE_OPENMP = bool(int(GS_PARALLEL)) if GS_PARALLEL else False + +if USE_OPENMP: + # just check if wanted + CAN_USE_OPENMP, FLAGS = check_openmp_support() + if CAN_USE_OPENMP: + print("## GSTOOLS setup: OpenMP found.") + print("## OpenMP flags:", FLAGS) + else: + print("## GSTOOLS setup: OpenMP not found.") +else: + print("## GSTOOLS setup: OpenMP not wanted by the user.") + FLAGS = [] + + +# cython extensions ########################################################### + + +CY_MODULES = [] +CY_MODULES.append( + Extension( + "gstools.field.summator", + [os.path.join("gstools", "field", "summator.pyx")], + include_dirs=[np.get_include()], + extra_compile_args=FLAGS, + extra_link_args=FLAGS, + ) +) +CY_MODULES.append( + Extension( + "gstools.variogram.estimator", + [os.path.join("gstools", "variogram", "estimator.pyx")], + language="c++", + include_dirs=[np.get_include()], + extra_compile_args=FLAGS, + extra_link_args=FLAGS, + ) +) +CY_MODULES.append( + Extension( + "gstools.krige.krigesum", + [os.path.join("gstools", "krige", "krigesum.pyx")], + include_dirs=[np.get_include()], + extra_compile_args=FLAGS, + extra_link_args=FLAGS, + ) +) +EXT_MODULES = cythonize(CY_MODULES) # annotate=True + +# embed signatures for sphinx +for ext_m in EXT_MODULES: + ext_m.cython_directives = {"embedsignature": True} + + +# setup ####################################################################### + + +setup(ext_modules=EXT_MODULES, include_dirs=[np.get_include()])