Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into sparse-permutation-…
Browse files Browse the repository at this point in the history
…tests-treeple-stats
  • Loading branch information
ryanhausen committed Sep 23, 2024
2 parents 248ac3b + 980de16 commit 6be17a0
Show file tree
Hide file tree
Showing 22 changed files with 236 additions and 114 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ repos:

# Ruff treeple
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.1
rev: v0.6.4
hooks:
- id: ruff
name: ruff treeple
Expand All @@ -31,7 +31,7 @@ repos:

# Ruff tutorials and examples
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.1
rev: v0.6.4
hooks:
- id: ruff
name: ruff tutorials and examples
Expand Down Expand Up @@ -67,7 +67,7 @@ repos:

# mypy
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.11.1
rev: v1.11.2
hooks:
- id: mypy
# Avoid the conflict between mne/__init__.py and mne/__init__.pyi by ignoring the former
Expand Down
5 changes: 5 additions & 0 deletions .spin/cmds.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import click
from spin import util
from spin.cmds import meson
from spin.cmds.meson import build_dir_option


def get_git_revision_hash(submodule) -> str:
Expand Down Expand Up @@ -145,14 +146,18 @@ def setup_submodule(forcesubmodule=False):
@click.option(
"--forcesubmodule", is_flag=True, help="Force submodule pull.", envvar="FORCE_SUBMODULE"
)
@build_dir_option
@click.pass_context
def build(
ctx,
*,
meson_args,
jobs=None,
clean=False,
verbose=False,
gcov=False,
quiet=False,
build_dir=None,
forcesubmodule=False,
):
"""Build treeple using submodules.
Expand Down
8 changes: 6 additions & 2 deletions DEVELOPING.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ If you are developing locally, you will need the build dependencies to compile t

pip install -r build_requirements.txt

Additionally, you need to install the latest build of scikit-learn:

pip install --force -r build_sklearn_requirements.txt

Other requirements can be installed as such:

pip install .
Expand Down Expand Up @@ -132,7 +136,7 @@ You can also do the same thing using Meson/Ninja itself. Run the following to bu
export PYTHONPATH=${PWD}/build/lib/python<YOUR_PYTHON_VERSION>/site-packages

# to check installation, you need to be in a different directory
cd docs;
cd docs;
python -c "from treeple import tree"
python -c "import sklearn; print(sklearn.__version__);"

Expand Down Expand Up @@ -188,7 +192,7 @@ GH Actions will build wheels for each Python version and OS. Then the wheels nee
<https://github.com/neurodata/treeple/actions/workflows/build_wheels.yml> will have all the wheels for common OSes built for each Python version.

2. Upload wheels to TestPyPI
This is to ensure that the wheels are built correctly and can be installed on a fresh environment. For more information, see <https://packaging.python.org/guides/using-testpypi/>. You will need to follow the instructions to create an account and get your API token for testpypi and pypi.
This is to ensure that the wheels are built correctly and can be installed on a fresh environment. For more information, see <https://packaging.python.org/guides/using-testpypi/>. You will need to follow the instructions to create an account and get your API token for testpypi and pypi.

```
twine upload dist/* --repository testpypi
Expand Down
2 changes: 1 addition & 1 deletion build_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ click
rich-click
doit
pydevtool
spin
spin>=0.12
build
3 changes: 3 additions & 0 deletions build_sklearn_requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
--pre
--extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple
scikit-learn
2 changes: 1 addition & 1 deletion meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ project(
license: 'PolyForm Noncommercial 1.0.0',
meson_version: '>= 1.1.0',
default_options: [
'c_std=c99',
'c_std=c11',
'cpp_std=c++14',
],
)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ build = [
'twine',
'meson',
'meson-python',
'spin',
'spin>=0.12',
'doit',
'scikit-learn>=1.5.0',
'Cython>=3.0.10',
Expand Down
4 changes: 3 additions & 1 deletion treeple/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
# https://github.com/ContinuumIO/anaconda-issues/issues/11294
os.environ.setdefault("KMP_INIT_AT_FORK", "FALSE")


try:
# This variable is injected in the __builtins__ by the build
# process. It is used to enable importing subpackages of sklearn when
Expand Down Expand Up @@ -64,7 +65,8 @@
msg = """Error importing treeple: you cannot import treeple while
being in treeple source directory; please exit the treeple source
tree first and relaunch your Python interpreter."""
raise ImportError(msg) from e
raise Exception(e)
# raise ImportError(msg) from e

__all__ = [
"_lib",
Expand Down
19 changes: 19 additions & 0 deletions treeple/_lib/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,22 @@ foreach ext: extensions
subdir: 'treeple/_lib/sklearn/utils/',
)
endforeach


# python_sources = [
# '__init__.py',
# ]

# py.install_sources(
# python_sources,
# subdir: 'treeple/_lib' # Folder relative to site-packages to install to
# )

# tempita = files('./sklearn/_build_utils/tempita.py')

# # Copy all the .py files to the install dir, rather than using
# # py.install_sources and needing to list them explicitly one by one
# # install_subdir('sklearn', install_dir: py.get_install_dir())
# install_subdir('sklearn', install_dir: join_paths(py.get_install_dir(), 'treeple/_lib'))

# subdir('sklearn')
2 changes: 1 addition & 1 deletion treeple/_lib/sklearn_fork
Submodule sklearn_fork updated 216 files
8 changes: 6 additions & 2 deletions treeple/ensemble/_honest_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,8 +720,12 @@ def oob_samples_(self):
oob_samples.append(_oob_samples)
return oob_samples

def _more_tags(self):
return {"multioutput": False}
def __sklearn_tags__(self):
    """Return the inherited estimator tags with two flags forced to ``False``.

    Starts from the parent class's tags, then clears
    ``input_tags.allow_nan`` and ``classifier_tags.multi_output`` before
    handing the tags object back.
    """
    estimator_tags = super().__sklearn_tags__()
    # XXX: nans should be supportable in HRF
    estimator_tags.input_tags.allow_nan = False
    # NOTE(review): released scikit-learn exposes ``multi_output`` on
    # ``target_tags`` rather than ``classifier_tags`` -- confirm which one the
    # vendored fork expects.
    estimator_tags.classifier_tags.multi_output = False
    return estimator_tags

def decision_path(self, X):
"""
Expand Down
12 changes: 8 additions & 4 deletions treeple/ensemble/_unsupervised_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@
)
from sklearn.metrics import calinski_harabasz_score
from sklearn.utils.parallel import Parallel, delayed
from sklearn.utils.validation import _check_sample_weight, check_is_fitted, check_random_state
from sklearn.utils.validation import (
_check_sample_weight,
check_is_fitted,
check_random_state,
validate_data,
)

from .._lib.sklearn.ensemble._forest import BaseForest
from .._lib.sklearn.tree._tree import DTYPE
Expand Down Expand Up @@ -85,10 +90,9 @@ def fit(self, X, y=None, sample_weight=None):
self : object
Returns the instance itself.
"""
self._validate_params()

# Validate or convert input data
X = self._validate_data(
X = validate_data(
self,
X,
dtype=DTYPE, # accept_sparse="csc",
)
Expand Down
1 change: 1 addition & 0 deletions treeple/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ scikit_learn_cython_args = [
'-X language_level=3', '-X boundscheck=' + boundscheck, '-X wraparound=False',
'-X initializedcheck=False', '-X nonecheck=False', '-X cdivision=True',
'-X profile=False',
'-X embedsignature=True',
# Needed for cython imports across subpackages, e.g. cluster pyx that
# cimports metrics pxd
'--include-dir', meson.global_build_root(),
Expand Down
15 changes: 11 additions & 4 deletions treeple/neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from sklearn.base import BaseEstimator, MetaEstimatorMixin
from sklearn.exceptions import NotFittedError
from sklearn.neighbors import NearestNeighbors
from sklearn.utils.validation import check_is_fitted
from sklearn.utils.validation import check_is_fitted, validate_data

from treeple.tree import DecisionTreeClassifier
from treeple.tree._neighbors import _compute_distance_matrix, compute_forest_similarity_matrix


Expand All @@ -31,13 +32,19 @@ class NearestNeighborsMetaEstimator(BaseEstimator, MetaEstimatorMixin):
The number of parallel jobs to run for neighbors, by default None.
"""

def __init__(self, estimator, n_neighbors=5, radius=1.0, algorithm="auto", n_jobs=None):
def __init__(self, estimator=None, n_neighbors=5, radius=1.0, algorithm="auto", n_jobs=None):
    # Constructor arguments are stored unmodified; nothing is validated or
    # copied here.
    self.estimator = estimator
    # Settings for the nearest-neighbors search.
    self.n_neighbors = n_neighbors
    self.radius = radius
    self.algorithm = algorithm
    self.n_jobs = n_jobs

def get_estimator(self):
    """Return the estimator that ``fit`` should work with.

    Returns
    -------
    estimator : object
        A copy (via ``copy``) of ``self.estimator`` when one was supplied;
        otherwise a default ``DecisionTreeClassifier(random_state=0)``.
    """
    # The default tree is only a fallback. The previous condition was
    # inverted: it discarded a user-supplied estimator and returned
    # ``copy(None)`` when no estimator was given.
    if self.estimator is None:
        return DecisionTreeClassifier(random_state=0)
    return copy(self.estimator)

def fit(self, X, y=None):
"""Fit the nearest neighbors estimator from the training dataset.
Expand All @@ -56,9 +63,9 @@ def fit(self, X, y=None):
self : object
Fitted estimator.
"""
X, y = self._validate_data(X, y, accept_sparse="csc")
X, y = validate_data(self, X, y, accept_sparse="csc")

self.estimator_ = copy(self.estimator)
self.estimator_ = self.get_estimator()
try:
check_is_fitted(self.estimator_)
except NotFittedError:
Expand Down
Loading

0 comments on commit 6be17a0

Please sign in to comment.