Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for new __sklearn_tags__ #205

Merged
merged 4 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,16 @@ jobs:
# as well as selected previous versions on
# https://pytorch.org/get-started/previous-versions/
torch-version: ["2.2.2", "2.4.0"]
sklearn-version: ["latest"]
include:
- os: windows-latest
torch-version: 2.4.0
python-version: "3.10"
sklearn-version: "latest"
- os: ubuntu-latest
torch-version: 2.4.0
python-version: "3.10"
sklearn-version: "legacy"

runs-on: ${{ matrix.os }}

Expand All @@ -32,7 +38,7 @@ jobs:
uses: actions/cache@v3
with:
path: ~/.cache/pip
key: pip-os_${{ runner.os }}-python_${{ matrix.python-version }}-torch_${{ matrix.torch-version }}
key: pip-os_${{ runner.os }}-python_${{ matrix.python-version }}-torch_${{ matrix.torch-version }}-sklearn_${{ matrix.sklearn-version }}

- name: Checkout code
uses: actions/checkout@v2
Expand All @@ -48,6 +54,11 @@ jobs:
python -m pip install torch==${{ matrix.torch-version }} --extra-index-url https://download.pytorch.org/whl/cpu
pip install '.[dev,datasets,integrations]'

- name: Check sklearn legacy version
if: matrix.sklearn-version == 'legacy'
run: |
pip install scikit-learn==1.4.2 '.[dev,datasets,integrations]'

- name: Run the formatter
run: |
make format
Expand Down
18 changes: 17 additions & 1 deletion cebra/integrations/sklearn/cebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@
import pkg_resources
import sklearn.utils.validation as sklearn_utils_validation
import torch
import sklearn
from sklearn.base import BaseEstimator
from sklearn.base import TransformerMixin
from sklearn.utils.metaestimators import available_if
from torch import nn

import cebra.data
Expand All @@ -41,6 +43,11 @@
import cebra.models
import cebra.solver

def check_version(estimator):
# NOTE(stes): required as a check for the old way of specifying tags
# https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165
from packaging import version
return version.parse(sklearn.__version__) < version.parse("1.6.dev")

def _init_loader(
is_cont: bool,
Expand Down Expand Up @@ -364,7 +371,7 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA":
return cebra_


class CEBRA(BaseEstimator, TransformerMixin):
class CEBRA(TransformerMixin, BaseEstimator):
"""CEBRA model defined as part of a ``scikit-learn``-like API.

Attributes:
Expand Down Expand Up @@ -1294,6 +1301,15 @@ def fit_transform(
callback_frequency=callback_frequency)
return self.transform(X)

def __sklearn_tags__(self):
# NOTE(stes): from 1.6.dev, this is the new way to specify tags
# https://scikit-learn.org/dev/developers/develop.html
# https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165
tags = super().__sklearn_tags__()
tags.non_deterministic = True
return tags

@available_if(check_version)
def _more_tags(self):
# NOTE(stes): This tag is needed as seeding is not fully implemented in the
# current version of CEBRA.
Expand Down
Loading