From bcf4b12d73bd7e2dda403f89711cc39fbfacbd4c Mon Sep 17 00:00:00 2001 From: Ryan Hausen Date: Thu, 5 Sep 2024 09:25:46 -0400 Subject: [PATCH 1/4] Update developer docs to ensure compatibility with latest scikit-learn. (#319) * updated developer documentation to install latest sklearn for compatibility --- DEVELOPING.md | 8 ++++++-- build_sklearn_requirements.txt | 3 +++ 2 files changed, 9 insertions(+), 2 deletions(-) create mode 100644 build_sklearn_requirements.txt diff --git a/DEVELOPING.md b/DEVELOPING.md index 9340552f..f160df0e 100644 --- a/DEVELOPING.md +++ b/DEVELOPING.md @@ -44,6 +44,10 @@ If you are developing locally, you will need the build dependencies to compile t pip install -r build_requirements.txt +Additionally, you need to install the latest build of scikit-learn: + + pip install --force -r build_sklearn_requirements.txt + Other requirements can be installed as such: pip install . @@ -132,7 +136,7 @@ You can also do the same thing using Meson/Ninja itself. Run the following to bu export PYTHONPATH=${PWD}/build/lib/python/site-packages # to check installation, you need to be in a different directory - cd docs; + cd docs; python -c "from treeple import tree" python -c "import sklearn; print(sklearn.__version__);" @@ -188,7 +192,7 @@ GH Actions will build wheels for each Python version and OS. Then the wheels nee will have all the wheels for common OSes built for each Python version. 2. Upload wheels to test PyPi -This is to ensure that the wheels are built correctly and can be installed on a fresh environment. For more information, see . You will need to follow the instructions to create an account and get your API token for testpypi and pypi. +This is to ensure that the wheels are built correctly and can be installed on a fresh environment. For more information, see . You will need to follow the instructions to create an account and get your API token for testpypi and pypi.
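As a quick check that the nightly scikit-learn from `build_sklearn_requirements.txt` is the active install, the version string can be inspected. A minimal sketch, assuming nightly wheels report a `.dev` suffix in `__version__`:

```
# Hedged sanity check: nightly scikit-learn wheels are assumed to carry a
# ".dev" development suffix in their version string; a release install
# would fail the assert.
import sklearn

print(sklearn.__version__)
assert "dev" in sklearn.__version__, "expected a nightly scikit-learn build"
```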
``` twine upload dist/* --repository testpypi diff --git a/build_sklearn_requirements.txt b/build_sklearn_requirements.txt new file mode 100644 index 00000000..8033f54d --- /dev/null +++ b/build_sklearn_requirements.txt @@ -0,0 +1,3 @@ +--pre +--extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple +scikit-learn \ No newline at end of file From ea67d06f937fa3ab58799bbe5f63840ee4a6f538 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Sep 2024 09:26:13 -0400 Subject: [PATCH 2/4] [pre-commit.ci] pre-commit autoupdate (#320) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.6.1 → v0.6.3](https://github.com/astral-sh/ruff-pre-commit/compare/v0.6.1...v0.6.3) - [github.com/astral-sh/ruff-pre-commit: v0.6.1 → v0.6.3](https://github.com/astral-sh/ruff-pre-commit/compare/v0.6.1...v0.6.3) - [github.com/pre-commit/mirrors-mypy: v1.11.1 → v1.11.2](https://github.com/pre-commit/mirrors-mypy/compare/v1.11.1...v1.11.2) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6fd0aecb..e744a306 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,7 +22,7 @@ repos: # Ruff treeple - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.1 + rev: v0.6.3 hooks: - id: ruff name: ruff treeple @@ -31,7 +31,7 @@ repos: # Ruff tutorials and examples - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.1 + rev: v0.6.3 hooks: - id: ruff name: ruff tutorials and examples @@ -67,7 +67,7 @@ repos: # mypy - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.11.1 + rev: v1.11.2 hooks: - id: mypy # Avoid the conflict between mne/__init__.py and mne/__init__.pyi by ignoring the former From 7e9dc2269e8fb237240a996bc4ed6ef6c9e966ec Mon Sep 17 00:00:00 2001 From: Adam Li Date: Mon, 9 Sep 2024 12:19:35 -0400 Subject: [PATCH 3/4] MAINT Clean up Cython files (#321) * Clean up Cython files in oblique and morf splitter * Migrate `self._validate_data` to `validate_data` in scikit-learn developer API * Update spin to v0.12+ * Update the C standard from C99 to C11 --------- Signed-off-by: Adam Li --- .spin/cmds.py | 5 ++ build_requirements.txt | 2 +- meson.build | 2 +- pyproject.toml | 2 +- treeple/__init__.py | 4 +- treeple/_lib/meson.build | 19 ++++++ treeple/_lib/sklearn_fork | 2 +- treeple/ensemble/_honest_forest.py | 8 ++- treeple/ensemble/_unsupervised_forest.py | 12 ++-- treeple/meson.build | 1 + treeple/neighbors.py | 15 +++-- treeple/tree/_classes.py | 82 ++++++++++++++++++------ treeple/tree/_neighbors.py | 4 -- treeple/tree/_oblique_splitter.pxd | 6 -- treeple/tree/_oblique_splitter.pyx | 42 ++++-------- treeple/tree/_utils.pxd | 38 +++++++++-- treeple/tree/_utils.pyx | 65 ++++++++++++++----- treeple/tree/manifold/_morf_splitter.pxd | 10 +-- treeple/tree/manifold/_morf_splitter.pyx | 14 ++-- 19 files changed, 224 insertions(+), 109 deletions(-) diff --git a/.spin/cmds.py b/.spin/cmds.py index 7a80393d..b5631b0e 100644 --- a/.spin/cmds.py +++ b/.spin/cmds.py @@ -5,6 +5,7 @@ import click from spin import util from spin.cmds import meson +from spin.cmds.meson import build_dir_option def get_git_revision_hash(submodule) -> str:
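The `validate_data` migration named in the commit message above follows one pattern throughout this patch. A minimal before/after sketch, assuming scikit-learn >= 1.6, where the `self._validate_data` method was replaced by a free function in `sklearn.utils.validation` that takes the estimator as its first argument (`MyEstimator` and its `dtype` argument are illustrative, not treeple code):

```
from sklearn.base import BaseEstimator
from sklearn.utils.validation import validate_data


class MyEstimator(BaseEstimator):
    def fit(self, X, y=None):
        # Old developer API (scikit-learn < 1.6):
        #     X = self._validate_data(X, dtype="numeric")
        # New developer API: same keyword arguments, but the estimator is
        # passed explicitly as the first positional argument.
        X = validate_data(self, X, dtype="numeric")
        return self
```

@@ -145,14 +146,18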
@@ def setup_submodule(forcesubmodule=False): @click.option( "--forcesubmodule", is_flag=True, help="Force submodule pull.", envvar="FORCE_SUBMODULE" ) +@build_dir_option @click.pass_context def build( ctx, + *, meson_args, jobs=None, clean=False, verbose=False, gcov=False, + quiet=False, + build_dir=None, forcesubmodule=False, ): """Build treeple using submodules. diff --git a/build_requirements.txt b/build_requirements.txt index 95bc6c98..ec63cfb3 100644 --- a/build_requirements.txt +++ b/build_requirements.txt @@ -8,5 +8,5 @@ click rich-click doit pydevtool -spin +spin>=0.12 build diff --git a/meson.build b/meson.build index 26f909de..07ec4c9c 100644 --- a/meson.build +++ b/meson.build @@ -8,7 +8,7 @@ project( license: 'PolyForm Noncommercial 1.0.0', meson_version: '>= 1.1.0', default_options: [ - 'c_std=c99', + 'c_std=c11', 'cpp_std=c++14', ], ) diff --git a/pyproject.toml b/pyproject.toml index 596d2408..c0a50d95 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,7 +68,7 @@ build = [ 'twine', 'meson', 'meson-python', - 'spin', + 'spin>=0.12', 'doit', 'scikit-learn>=1.5.0', 'Cython>=3.0.10', diff --git a/treeple/__init__.py b/treeple/__init__.py index 2a70afef..dafad7de 100644 --- a/treeple/__init__.py +++ b/treeple/__init__.py @@ -22,6 +22,7 @@ # https://github.com/ContinuumIO/anaconda-issues/issues/11294 os.environ.setdefault("KMP_INIT_AT_FORK", "FALSE") + try: # This variable is injected in the __builtins__ by the build # process. It is used to enable importing subpackages of sklearn when @@ -64,7 +65,8 @@ msg = """Error importing treeple: you cannot import treeple while being in treeple source directory; please exit the treeple source tree first and relaunch your Python interpreter.""" - raise ImportError(msg) from e + raise Exception(e) + # raise ImportError(msg) from e __all__ = [ "_lib", diff --git a/treeple/_lib/meson.build b/treeple/_lib/meson.build index 5dd37c86..ae83cf4a 100644 --- a/treeple/_lib/meson.build +++ b/treeple/_lib/meson.build @@ -94,3 +94,22 @@ foreach ext: extensions subdir: 'treeple/_lib/sklearn/utils/', ) endforeach + + +# python_sources = [ +# '__init__.py', +# ] + +# py.install_sources( +# python_sources, +# subdir: 'treeple/_lib' # Folder relative to site-packages to install to +# ) + +# tempita = files('./sklearn/_build_utils/tempita.py') + +# # Copy all the .py files to the install dir, rather than using +# # py.install_sources and needing to list them explicitly one by one +# # install_subdir('sklearn', install_dir: py.get_install_dir()) +# install_subdir('sklearn', install_dir: join_paths(py.get_install_dir(), 'treeple/_lib')) + +# subdir('sklearn') diff --git a/treeple/_lib/sklearn_fork b/treeple/_lib/sklearn_fork index ac5cb8ab..e4b9728c 160000 --- a/treeple/_lib/sklearn_fork +++ b/treeple/_lib/sklearn_fork @@ -1 +1 @@ -Subproject commit ac5cb8abd5c9b425c3c02a2be1d91296adf643a3 +Subproject commit e4b9728cb8667d0a40ed0c6c45f0414811f5f1f8 diff --git a/treeple/ensemble/_honest_forest.py b/treeple/ensemble/_honest_forest.py index 96c01062..447371b3 100644 --- a/treeple/ensemble/_honest_forest.py +++ b/treeple/ensemble/_honest_forest.py @@ -720,8 +720,12 @@ def oob_samples_(self): oob_samples.append(_oob_samples) return oob_samples - def _more_tags(self): - return {"multioutput": False} + def __sklearn_tags__(self): + # XXX: nans should be supportable in HRF + tags = super().__sklearn_tags__() + tags.classifier_tags.multi_output = False + tags.input_tags.allow_nan = False + return tags def decision_path(self, X): """
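Likewise, the `_more_tags` to `__sklearn_tags__` change above swaps the old dict-based estimator tags for the dataclass-based API. A minimal sketch of the pattern, assuming scikit-learn >= 1.6 (`MyForest` is illustrative, not treeple code):

```
from sklearn.base import BaseEstimator, ClassifierMixin


class MyForest(ClassifierMixin, BaseEstimator):
    # Old developer API (returned a plain dict):
    #     def _more_tags(self):
    #         return {"multilabel": True, "allow_nan": False}

    # New developer API: mutate and return the inherited Tags dataclass.
    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        tags.classifier_tags.multi_label = True
        tags.input_tags.allow_nan = False
        return tags
```

diff --git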
a/treeple/ensemble/_unsupervised_forest.py b/treeple/ensemble/_unsupervised_forest.py index a66c330a..980c1ebb 100644 --- a/treeple/ensemble/_unsupervised_forest.py +++ b/treeple/ensemble/_unsupervised_forest.py @@ -21,7 +21,12 @@ ) from sklearn.metrics import calinski_harabasz_score from sklearn.utils.parallel import Parallel, delayed -from sklearn.utils.validation import _check_sample_weight, check_is_fitted, check_random_state +from sklearn.utils.validation import ( + _check_sample_weight, + check_is_fitted, + check_random_state, + validate_data, +) from .._lib.sklearn.ensemble._forest import BaseForest from .._lib.sklearn.tree._tree import DTYPE @@ -85,10 +90,9 @@ def fit(self, X, y=None, sample_weight=None): self : object Returns the instance itself. """ - self._validate_params() - # Validate or convert input data - X = self._validate_data( + X = validate_data( + self, X, dtype=DTYPE, # accept_sparse="csc", ) diff --git a/treeple/meson.build b/treeple/meson.build index 3d1715db..4801d053 100644 --- a/treeple/meson.build +++ b/treeple/meson.build @@ -103,6 +103,7 @@ scikit_learn_cython_args = [ '-X language_level=3', '-X boundscheck=' + boundscheck, '-X wraparound=False', '-X initializedcheck=False', '-X nonecheck=False', '-X cdivision=True', '-X profile=False', + '-X embedsignature=True', # Needed for cython imports across subpackages, e.g. cluster pyx that # cimports metrics pxd '--include-dir', meson.global_build_root(), diff --git a/treeple/neighbors.py b/treeple/neighbors.py index 473b4363..b16e732f 100644 --- a/treeple/neighbors.py +++ b/treeple/neighbors.py @@ -5,8 +5,9 @@ from sklearn.base import BaseEstimator, MetaEstimatorMixin from sklearn.exceptions import NotFittedError from sklearn.neighbors import NearestNeighbors -from sklearn.utils.validation import check_is_fitted +from sklearn.utils.validation import check_is_fitted, validate_data +from treeple.tree import DecisionTreeClassifier from treeple.tree._neighbors import _compute_distance_matrix, compute_forest_similarity_matrix @@ -31,13 +32,19 @@ class NearestNeighborsMetaEstimator(BaseEstimator, MetaEstimatorMixin): The number of parallel jobs to run for neighbors, by default None. """ - def __init__(self, estimator, n_neighbors=5, radius=1.0, algorithm="auto", n_jobs=None): + def __init__(self, estimator=None, n_neighbors=5, radius=1.0, algorithm="auto", n_jobs=None): self.estimator = estimator self.n_neighbors = n_neighbors self.algorithm = algorithm self.radius = radius self.n_jobs = n_jobs + def get_estimator(self): + if self.estimator is None: + return DecisionTreeClassifier(random_state=0) + else: + return copy(self.estimator) + def fit(self, X, y=None): """Fit the nearest neighbors estimator from the training dataset. @@ -56,9 +63,9 @@ def fit(self, X, y=None): self : object Fitted estimator.
""" - X, y = self._validate_data(X, y, accept_sparse="csc") + X, y = validate_data(self, X, y, accept_sparse="csc") - self.estimator_ = copy(self.estimator) + self.estimator_ = self.get_estimator() try: check_is_fitted(self.estimator_) except NotFittedError: diff --git a/treeple/tree/_classes.py b/treeple/tree/_classes.py index 16eb6ea5..aa93d4c0 100644 --- a/treeple/tree/_classes.py +++ b/treeple/tree/_classes.py @@ -8,7 +8,7 @@ from sklearn.cluster import AgglomerativeClustering from sklearn.utils import check_random_state from sklearn.utils._param_validation import Interval -from sklearn.utils.validation import check_is_fitted +from sklearn.utils.validation import check_is_fitted, validate_data from .._lib.sklearn.tree import ( BaseDecisionTree, @@ -216,7 +216,7 @@ def fit(self, X, y=None, sample_weight=None, check_input=True): if check_input: # TODO: allow X to be sparse check_X_params = dict(dtype=DTYPE) # , accept_sparse="csc" - X = self._validate_data(X, validate_separately=(check_X_params)) + X = validate_data(self, X, validate_separately=(check_X_params)) if issparse(X): X.sort_indices() @@ -378,6 +378,13 @@ def _assign_labels(self, affinity_matrix): predict_labels = cluster.fit_predict(affinity_matrix) return predict_labels + def __sklearn_tags__(self): + # XXX: nans should be supportable in SPORF by just using RF-like splits on missing values + # However, for MORF it is not supported + tags = super().__sklearn_tags__() + tags.input_tags.allow_nan = False + return tags + class UnsupervisedObliqueDecisionTree(UnsupervisedDecisionTree): """Unsupervised oblique decision tree. @@ -577,6 +584,13 @@ def _build_tree( builder.build(self.tree_, X, sample_weight) return self + def __sklearn_tags__(self): + # XXX: nans should be supportable in SPORF by just using RF-like splits on missing values + # However, for MORF it is not supported + tags = super().__sklearn_tags__() + tags.input_tags.allow_nan = False + return tags + class ObliqueDecisionTreeClassifier(SimMatrixMixin, DecisionTreeClassifier): """An oblique decision tree classifier. @@ -820,7 +834,7 @@ class ObliqueDecisionTreeClassifier(SimMatrixMixin, DecisionTreeClassifier): tree_type = "oblique" - _parameter_constraints = { + _parameter_constraints: dict = { **DecisionTreeClassifier._parameter_constraints, "feature_combinations": [ Interval(Real, 1.0, None, closed="left"), @@ -1070,6 +1084,13 @@ def _update_tree(self, X, y, sample_weight): self._prune_tree() return self + def __sklearn_tags__(self): + # XXX: nans should be supportable in SPORF by just using RF-like splits on missing values + # However, for MORF it is not supported + tags = super().__sklearn_tags__() + tags.input_tags.allow_nan = False + return tags + class ObliqueDecisionTreeRegressor(SimMatrixMixin, DecisionTreeRegressor): """An oblique decision tree Regressor. 
@@ -1283,7 +1304,7 @@ class ObliqueDecisionTreeRegressor(SimMatrixMixin, DecisionTreeRegressor): tree_type = "oblique" - _parameter_constraints = { + _parameter_constraints: dict = { **DecisionTreeRegressor._parameter_constraints, "feature_combinations": [ Interval(Real, 1.0, None, closed="left"), @@ -1450,6 +1471,13 @@ def _build_tree( builder.build(self.tree_, X, y, sample_weight, None) return self + def __sklearn_tags__(self): + # XXX: nans should be supportable in SPORF by just using RF-like splits on missing values + # However, for MORF it is not supported + tags = super().__sklearn_tags__() + tags.input_tags.allow_nan = False + return tags + class PatchObliqueDecisionTreeClassifier(SimMatrixMixin, DecisionTreeClassifier): """An oblique decision tree classifier that operates over patches of data. @@ -1684,7 +1712,7 @@ class PatchObliqueDecisionTreeClassifier(SimMatrixMixin, DecisionTreeClassifier) """ tree_type = "oblique" - _parameter_constraints = { + _parameter_constraints: dict = { **DecisionTreeClassifier._parameter_constraints, "min_patch_dims": ["array-like", None], "max_patch_dims": ["array-like", None], @@ -1798,8 +1826,8 @@ def _build_tree( self.feature_combinations_ = 1 if self.feature_weight is not None: - self.feature_weight = self._validate_data( - self.feature_weight, ensure_2d=True, dtype=DTYPE + self.feature_weight = validate_data( + self, self.feature_weight, ensure_2d=True, dtype=DTYPE ) if self.feature_weight.shape != X.shape: raise ValueError( @@ -1927,11 +1955,13 @@ def _build_tree( return self - def _more_tags(self): + def __sklearn_tags__(self): # XXX: nans should be supportable in SPORF by just using RF-like splits on missing values # However, for MORF it is not supported - allow_nan = False - return {"multilabel": True, "allow_nan": allow_nan} + tags = super().__sklearn_tags__() + tags.classifier_tags.multi_label = True + tags.input_tags.allow_nan = False + return tags @property def _inheritable_fitted_attribute(self): @@ -2166,7 +2196,7 @@ class PatchObliqueDecisionTreeRegressor(SimMatrixMixin, DecisionTreeRegressor): """ tree_type = "oblique" - _parameter_constraints = { + _parameter_constraints: dict = { **DecisionTreeRegressor._parameter_constraints, "min_patch_dims": ["array-like", None], "max_patch_dims": ["array-like", None], @@ -2277,8 +2307,8 @@ def _build_tree( self.feature_combinations_ = 1 if self.feature_weight is not None: - self.feature_weight = self._validate_data( - self.feature_weight, ensure_2d=True, dtype=DTYPE + self.feature_weight = validate_data( + self, self.feature_weight, ensure_2d=True, dtype=DTYPE ) if self.feature_weight.shape != X.shape: raise ValueError( @@ -2407,11 +2437,13 @@ def _build_tree( return self - def _more_tags(self): + def __sklearn_tags__(self): # XXX: nans should be supportable in SPORF by just using RF-like splits on missing values # However, for MORF it is not supported - allow_nan = False - return {"multilabel": True, "allow_nan": allow_nan} + tags = super().__sklearn_tags__() + tags.regressor_tags.multi_label = True + tags.input_tags.allow_nan = False + return tags class ExtraObliqueDecisionTreeClassifier(SimMatrixMixin, DecisionTreeClassifier): @@ -2669,7 +2701,7 @@ class ExtraObliqueDecisionTreeClassifier(SimMatrixMixin, DecisionTreeClassifier) tree_type = "oblique" - _parameter_constraints = { + _parameter_constraints: dict = { **DecisionTreeClassifier._parameter_constraints, "feature_combinations": [ Interval(Real, 1.0, None, closed="left"),
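For reference, the `Interval(Real, 1.0, None, closed="left")` entries in the `_parameter_constraints` dicts above encode the half-open range [1.0, inf) for `feature_combinations`. A small sketch of how such a constraint behaves, assuming scikit-learn's private `sklearn.utils._param_validation` helpers (illustration only):

```
from numbers import Real

from sklearn.utils._param_validation import Interval

# closed="left" includes the lower bound; None means unbounded above.
constraint = Interval(Real, 1.0, None, closed="left")
assert constraint.is_satisfied_by(1.0)       # left bound is included
assert constraint.is_satisfied_by(2.5)
assert not constraint.is_satisfied_by(0.99)  # below the lower bound
```

@@ -2846,6 +2878,13 @@ def _inheritable_fitted_attribute(self):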
"feature_combinations_", ] + def __sklearn_tags__(self): + # XXX: nans should be supportable in SPORF by just using RF-like splits on missing values + # However, for MORF it is not supported + tags = super().__sklearn_tags__() + tags.input_tags.allow_nan = False + return tags + class ExtraObliqueDecisionTreeRegressor(SimMatrixMixin, DecisionTreeRegressor): """An oblique decision tree Regressor. @@ -3069,7 +3108,7 @@ class ExtraObliqueDecisionTreeRegressor(SimMatrixMixin, DecisionTreeRegressor): -0.26552594, -0.00642017, -0.07108117, -0.40726765, -0.40315294]) """ - _parameter_constraints = { + _parameter_constraints: dict = { **DecisionTreeRegressor._parameter_constraints, "feature_combinations": [ Interval(Real, 1.0, None, closed="left"), @@ -3237,3 +3276,10 @@ def _build_tree( builder.build(self.tree_, X, y, sample_weight) return self + + def __sklearn_tags__(self): + # XXX: nans should be supportable in SPORF by just using RF-like splits on missing values + # However, for MORF it is not supported + tags = super().__sklearn_tags__() + tags.input_tags.allow_nan = False + return tags diff --git a/treeple/tree/_neighbors.py b/treeple/tree/_neighbors.py index 94f2c8f1..93d8ff1a 100644 --- a/treeple/tree/_neighbors.py +++ b/treeple/tree/_neighbors.py @@ -64,7 +64,3 @@ def compute_similarity_matrix(self, X): The similarity matrix among the samples. """ return compute_forest_similarity_matrix(self, X) - - def _more_tags(self): - # XXX: no treeple estimators support NaNs as of now - return {"allow_nan": False} diff --git a/treeple/tree/_oblique_splitter.pxd b/treeple/tree/_oblique_splitter.pxd index 124a66dd..65ca16e1 100644 --- a/treeple/tree/_oblique_splitter.pxd +++ b/treeple/tree/_oblique_splitter.pxd @@ -83,12 +83,6 @@ cdef class BaseObliqueSplitter(Splitter): SplitRecord* split, ) except -1 nogil - cdef inline void fisher_yates_shuffle_memview( - self, - intp_t[::1] indices_to_sample, - intp_t grid_size, - uint32_t* random_state - ) noexcept nogil cdef class ObliqueSplitter(BaseObliqueSplitter): # The splitter searches in the input space for a linear combination of features and a threshold diff --git a/treeple/tree/_oblique_splitter.pyx b/treeple/tree/_oblique_splitter.pyx index ca77a30a..0cceac66 100644 --- a/treeple/tree/_oblique_splitter.pyx +++ b/treeple/tree/_oblique_splitter.pyx @@ -11,6 +11,7 @@ from libcpp.vector cimport vector from .._lib.sklearn.tree._criterion cimport Criterion from .._lib.sklearn.tree._utils cimport rand_int, rand_uniform +from ._utils cimport fisher_yates_shuffle cdef float64_t INFINITY = np.inf @@ -46,8 +47,12 @@ cdef class BaseObliqueSplitter(Splitter): def __setstate__(self, d): pass - cdef int node_reset(self, intp_t start, intp_t end, - float64_t* weighted_n_node_samples) except -1 nogil: + cdef int node_reset( + self, + intp_t start, + intp_t end, + float64_t* weighted_n_node_samples + ) except -1 nogil: """Reset splitter on node samples[start:end]. 
Returns -1 in case of failure to allocate memory (and raise MemoryError) @@ -62,17 +67,7 @@ cdef class BaseObliqueSplitter(Splitter): weighted_n_node_samples : ndarray, dtype=float64_t pointer The total weight of those samples """ - - self.start = start - self.end = end - - self.criterion.init(self.y, - self.sample_weight, - self.weighted_n_samples, - self.samples) - self.criterion.set_sample_pointers(start, end) - - weighted_n_node_samples[0] = self.criterion.weighted_n_node_samples + Splitter.node_reset(self, start, end, weighted_n_node_samples) # Clear all projection vectors for i in range(self.max_features): @@ -102,8 +97,8 @@ cdef class BaseObliqueSplitter(Splitter): intp_t end, const intp_t[:] samples, float32_t[:] feature_values, - vector[float32_t]* proj_vec_weights, # weights of the vector (max_features,) - vector[intp_t]* proj_vec_indices # indices of the features (max_features,) + vector[float32_t]* proj_vec_weights, # weights of the vector (n_non_zeros,) + vector[intp_t]* proj_vec_indices # indices of the features (n_non_zeros,) ) noexcept nogil: """Compute the feature values for the samples[start:end] range. @@ -126,19 +121,6 @@ cdef class BaseObliqueSplitter(Splitter): feature_values[idx] = 0.0 feature_values[idx] += self.X[samples[idx], col_idx] * col_weight - cdef inline void fisher_yates_shuffle_memview( - self, - intp_t[::1] indices_to_sample, - intp_t grid_size, - uint32_t* random_state, - ) noexcept nogil: - cdef intp_t i, j - - # XXX: should this be `i` or `i+1`? for valid Fisher-Yates? - for i in range(0, grid_size - 1): - j = rand_int(i, grid_size, random_state) - indices_to_sample[j], indices_to_sample[i] = \ - indices_to_sample[i], indices_to_sample[j] cdef class ObliqueSplitter(BaseObliqueSplitter): def __cinit__( @@ -257,7 +239,7 @@ cdef class ObliqueSplitter(BaseObliqueSplitter): cdef intp_t grid_size = self.max_features * self.n_features # shuffle indices over the 2D grid to sample using Fisher-Yates - self.fisher_yates_shuffle_memview(indices_to_sample, grid_size, random_state) + fisher_yates_shuffle(indices_to_sample, grid_size, random_state) # sample 'n_non_zeros' in a mtry X n_features projection matrix # which consists of +/- 1's chosen at a 1/2s rate @@ -309,7 +291,7 @@ cdef class BestObliqueSplitter(ObliqueSplitter): cdef intp_t end = self.end # pointer array to store feature values to split on - cdef float32_t[::1] feature_values = self.feature_values + cdef float32_t[::1] feature_values = self.feature_values cdef intp_t max_features = self.max_features cdef intp_t min_samples_leaf = self.min_samples_leaf diff --git a/treeple/tree/_utils.pxd b/treeple/tree/_utils.pxd index c814cc16..ba270779 100644 --- a/treeple/tree/_utils.pxd +++ b/treeple/tree/_utils.pxd @@ -1,3 +1,5 @@ +from libcpp.vector cimport vector + import numpy as np cimport numpy as cnp @@ -7,15 +9,41 @@ cnp.import_array() from .._lib.sklearn.tree._splitter cimport SplitRecord from .._lib.sklearn.utils._typedefs cimport float32_t, float64_t, int32_t, intp_t, uint32_t +ctypedef fused vector_or_memview: + vector[intp_t] + intp_t[::1] + intp_t[:] + + +cdef void fisher_yates_shuffle( + vector_or_memview indices_to_sample, + intp_t grid_size, + uint32_t* random_state, +) noexcept nogil -cdef int rand_weighted_binary(float64_t p0, uint32_t* random_state) noexcept nogil + +cdef int rand_weighted_binary( + float64_t p0, + uint32_t* random_state +) noexcept nogil cpdef unravel_index( - intp_t index, cnp.ndarray[intp_t, ndim=1] shape + intp_t index, + cnp.ndarray[intp_t, ndim=1] shape ) -cpdef 
ravel_multi_index(intp_t[:] coords, const intp_t[:] shape) +cpdef ravel_multi_index( + intp_t[:] coords, + const intp_t[:] shape +) -cdef void unravel_index_cython(intp_t index, const intp_t[:] shape, intp_t[:] coords) noexcept nogil +cdef void unravel_index_cython( + intp_t index, + const intp_t[:] shape, + vector_or_memview coords +) noexcept nogil -cdef intp_t ravel_multi_index_cython(intp_t[:] coords, const intp_t[:] shape) noexcept nogil +cdef intp_t ravel_multi_index_cython( + vector_or_memview coords, + const intp_t[:] shape +) noexcept nogil diff --git a/treeple/tree/_utils.pyx b/treeple/tree/_utils.pyx index 197b82ec..7ce48977 100644 --- a/treeple/tree/_utils.pyx +++ b/treeple/tree/_utils.pyx @@ -11,10 +11,40 @@ cimport numpy as cnp cnp.import_array() -from .._lib.sklearn.tree._utils cimport rand_uniform +from .._lib.sklearn.tree._utils cimport rand_int, rand_uniform -cdef inline int rand_weighted_binary(float64_t p0, uint32_t* random_state) noexcept nogil: +cdef inline void fisher_yates_shuffle( + vector_or_memview indices_to_sample, + intp_t grid_size, + uint32_t* random_state, +) noexcept nogil: + """Shuffle the indices in place using the Fisher-Yates algorithm. + Parameters + ---------- + indices_to_sample : A C++ vector or 1D memoryview + The indices to shuffle. + grid_size : intp_t + The size of the grid to shuffle. This is explicitly passed in + to support the templated `vector_or_memview` type, which allows + for both C++ vectors and Cython memoryviews. Getting the length + of the two types requires different APIs. + random_state : uint32_t* + The random state. + """ + cdef intp_t i, j + + # XXX: should this be `i` or `i+1`? for valid Fisher-Yates? + for i in range(0, grid_size - 1): + j = rand_int(i, grid_size, random_state) + indices_to_sample[j], indices_to_sample[i] = \ + indices_to_sample[i], indices_to_sample[j] + + +cdef inline int rand_weighted_binary( + float64_t p0, + uint32_t* random_state +) noexcept nogil: """Sample from integers 0 and 1 with different probabilities. Parameters @@ -54,7 +84,9 @@ cpdef unravel_index( index = np.intp(index) shape = np.array(shape) coords = np.empty(shape.shape[0], dtype=np.intp) - unravel_index_cython(index, shape, coords) + cdef const intp_t[:] shape_memview = shape + cdef intp_t[:] coords_memview = coords + unravel_index_cython(index, shape_memview, coords_memview) return coords @@ -83,7 +115,11 @@ cpdef ravel_multi_index(intp_t[:] coords, const intp_t[:] shape): return ravel_multi_index_cython(coords, shape) -cdef void unravel_index_cython(intp_t index, const intp_t[:] shape, intp_t[:] coords) noexcept nogil: +cdef inline void unravel_index_cython( + intp_t index, + const intp_t[:] shape, + vector_or_memview coords +) noexcept nogil: """Converts a flat index into a tuple of coordinate arrays. Parameters @@ -92,13 +128,9 @@ cdef void unravel_index_cython(intp_t index, const intp_t[:] shape, intp_t[:] co The flat index to be converted. shape : numpy.ndarray[intp_t, ndim=1] The shape of the array into which the flat index should be converted. - coords : numpy.ndarray[intp_t, ndim=1] - A preinitialized memoryview array of coordinate arrays to be converted. - - Returns - ------- - numpy.ndarray[intp_t, ndim=1] - An array of coordinate arrays, with each coordinate array having the same shape as the input `shape`. + coords : intp_t[:] or vector[intp_t] + A preinitialized array of coordinates to store the result of the + unraveled `index`.
""" cdef intp_t ndim = shape.shape[0] cdef intp_t j, size @@ -109,13 +141,16 @@ cdef void unravel_index_cython(intp_t index, const intp_t[:] shape, intp_t[:] co index //= size -cdef intp_t ravel_multi_index_cython(intp_t[:] coords, const intp_t[:] shape) noexcept nogil: - """Converts a tuple of coordinate arrays into a flat index. +cdef inline intp_t ravel_multi_index_cython( + vector_or_memview coords, + const intp_t[:] shape +) noexcept nogil: + """Converts a tuple of coordinate arrays into a flat index in the vectorized dimension. Parameters ---------- - coords : numpy.ndarray[intp_t, ndim=1] - An array of coordinate arrays to be converted. + coords : intp_t[:] or vector[intp_t] + An array of coordinates to be converted and vectorized into a sinlg shape : numpy.ndarray[intp_t, ndim=1] The shape of the array into which the coordinates should be converted. diff --git a/treeple/tree/manifold/_morf_splitter.pxd b/treeple/tree/manifold/_morf_splitter.pxd index a0a61a4d..2b65fd3b 100644 --- a/treeple/tree/manifold/_morf_splitter.pxd +++ b/treeple/tree/manifold/_morf_splitter.pxd @@ -32,14 +32,6 @@ cdef class PatchSplitter(BestObliqueSplitter): # an input data vector. The input data is vectorized, so `data_height` and # `data_width` are used to determine the vectorized indices corresponding to # (x,y) coordinates in the original un-vectorized data. - - cdef public intp_t max_patch_height # Maximum height of the patch to sample - cdef public intp_t max_patch_width # Maximum width of the patch to sample - cdef public intp_t min_patch_height # Minimum height of the patch to sample - cdef public intp_t min_patch_width # Minimum width of the patch to sample - cdef public intp_t data_height # Height of the input data - cdef public intp_t data_width # Width of the input data - cdef public intp_t ndim # The number of dimensions of the input data cdef const intp_t[:] data_dims # The dimensions of the input data @@ -56,7 +48,7 @@ cdef class PatchSplitter(BestObliqueSplitter): cdef intp_t[::1] _index_data_buffer cdef intp_t[::1] _index_patch_buffer - cdef intp_t[:] patch_dims_buff # A buffer to store the dimensions of the sampled patch + cdef intp_t[:] patch_sampled_size # A buffer to store the dimensions of the sampled patch cdef intp_t[:] unraveled_patch_point # A buffer to store the unraveled patch point # All oblique splitters (i.e. 
non-axis aligned splitters) require a diff --git a/treeple/tree/manifold/_morf_splitter.pyx b/treeple/tree/manifold/_morf_splitter.pyx index d6c8d012..f1eaf291 100644 --- a/treeple/tree/manifold/_morf_splitter.pyx +++ b/treeple/tree/manifold/_morf_splitter.pyx @@ -151,7 +151,7 @@ cdef class BestPatchSplitter(BaseDensePatchSplitter): self.data_dims = data_dims # create a buffer for storing the patch dimensions sampled per projection matrix - self.patch_dims_buff = np.zeros(data_dims.shape[0], dtype=np.intp) + self.patch_sampled_size = np.zeros(data_dims.shape[0], dtype=np.intp) self.unraveled_patch_point = np.zeros(data_dims.shape[0], dtype=np.intp) # store the min and max patch dimension constraints @@ -237,7 +237,7 @@ cdef class BestPatchSplitter(BaseDensePatchSplitter): top_left_patch_seed = rand_int(0, delta_patch_dim, random_state) # write to buffer - self.patch_dims_buff[idx] = patch_dim + self.patch_sampled_size[idx] = patch_dim patch_size *= patch_dim elif self.boundary == "wrap": # add circular boundary conditions @@ -251,7 +251,7 @@ cdef class BestPatchSplitter(BaseDensePatchSplitter): # resample the patch dimension due to padding patch_dim = min(patch_dim, min(dim+1, self.data_dims[idx] + patch_dim - dim - 1)) - self.patch_dims_buff[idx] = patch_dim + self.patch_sampled_size[idx] = patch_dim patch_size *= patch_dim # TODO: make this work @@ -283,7 +283,7 @@ cdef class BestPatchSplitter(BaseDensePatchSplitter): cdef intp_t top_left_patch_seed # size of the sampled patch, which is just the size of the n-dim patch # (\prod_i self.patch_sampled_size[i]) cdef intp_t patch_size for proj_i in range(0, max_features): @@ -299,7 +299,7 @@ cdef class BestPatchSplitter(BaseDensePatchSplitter): proj_i, patch_size, top_left_patch_seed, - self.patch_dims_buff + self.patch_sampled_size ) cdef void sample_proj_vec( @@ -389,7 +389,7 @@ cdef class BestPatchSplitter(BaseDensePatchSplitter): if not self.dim_contiguous[idx]: row_index += ( (self.unraveled_patch_point[idx] // other_dims_offset) % self.patch_sampled_size[idx] ) * other_dims_offset other_dims_offset //= self.data_dims[idx] @@ -445,7 +445,7 @@ cdef class BestPatchSplitterTester(BestPatchSplitter): """A class to expose a Python interface for testing.""" cpdef sample_top_left_seed_cpdef(self): top_left_patch_seed, patch_size = self.sample_top_left_seed() - patch_dims = np.array(self.patch_dims_buff, dtype=np.intp) + patch_dims = np.array(self.patch_sampled_size, dtype=np.intp) return top_left_patch_seed, patch_size, patch_dims cpdef sample_projection_vector( From 980de161f12cd60e9be5a17a665b08c2d47e2552 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 11 Sep 2024 15:15:21 -0400 Subject: [PATCH 4/4] [pre-commit.ci] pre-commit autoupdate (#324) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.6.3 → v0.6.4](https://github.com/astral-sh/ruff-pre-commit/compare/v0.6.3...v0.6.4) - [github.com/astral-sh/ruff-pre-commit: v0.6.3 → v0.6.4](https://github.com/astral-sh/ruff-pre-commit/compare/v0.6.3...v0.6.4) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e744a306..0b96a2de 100644 ---
a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,7 +22,7 @@ repos: # Ruff treeple - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.3 + rev: v0.6.4 hooks: - id: ruff name: ruff treeple @@ -31,7 +31,7 @@ repos: # Ruff tutorials and examples - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.3 + rev: v0.6.4 hooks: - id: ruff name: ruff tutorials and examples