Skip to content

Commit

Permalink
Change MRO so that sklearn tags are set correctly
Browse files Browse the repository at this point in the history
This should fix documentation errors.
  • Loading branch information
vnmabus committed Feb 3, 2025
1 parent 6183d26 commit dd373cd
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 12 deletions.
4 changes: 2 additions & 2 deletions skfda/_utils/_neighbors_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,8 +490,8 @@ def radius_neighbors_graph(


class NeighborsClassifierMixin(
NeighborsBase[Input, TargetClassification],
ClassifierMixin[Input, TargetClassification],
NeighborsBase[Input, TargetClassification],
):
"""Mixin class for classifiers based in nearest neighbors."""

Expand Down Expand Up @@ -572,8 +572,8 @@ def fit(


class NeighborsRegressorMixin(
NeighborsBase[Input, TargetRegression],
RegressorMixin[Input, TargetRegression],
NeighborsBase[Input, TargetRegression],
):
"""Mixin class for the regressors based on neighbors."""

Expand Down
36 changes: 28 additions & 8 deletions skfda/_utils/_sklearn_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ class BaseEstimator( # noqa: D101

class TransformerMixin( # noqa: D101
ABC,
Generic[Input, Output, Target],
sklearn.base.TransformerMixin, # type: ignore[misc]
Generic[Input, Output, Target],
):

@overload
Expand All @@ -52,6 +52,10 @@ def fit( # noqa: D102
X: Input,
y: Target | None = None,
) -> SelfType:
fit = getattr(super(), "fit", None)
if fit:
return super().fit(X, y)

return self

@overload
Expand All @@ -75,6 +79,10 @@ def fit_transform( # noqa: D102
y: Target | None = None,
**fit_params: Any,
) -> Output:
fit_transform = getattr(super(), "fit_transform", None)
if fit_transform:
return fit_transform(X, y, **fit_params)

if y is None:
return self.fit( # type: ignore[no-any-return]
X,
Expand All @@ -97,41 +105,49 @@ def transform( # noqa: D102
self: SelfType,
X: Input,
) -> Output:
pass
return super().transform(X)


class OutlierMixin( # noqa: D101
ABC,
Generic[Input],
sklearn.base.OutlierMixin, # type: ignore[misc]
Generic[Input],
):

def fit_predict( # noqa: D102
self,
X: Input,
y: object = None,
) -> NDArrayInt:
fit_predict = getattr(super(), "fit_predict", None)
if fit_predict:
return fit_predict(X, y)

return self.fit(X, y).predict(X) # type: ignore[no-any-return]


class ClassifierMixin( # noqa: D101
ABC,
Generic[Input, TargetPrediction],
sklearn.base.ClassifierMixin, # type: ignore[misc]
Generic[Input, TargetPrediction],
):
def fit( # noqa: D102
self: SelfType,
X: Input,
y: TargetPrediction,
) -> SelfType:
fit = getattr(super(), "fit", None)
if fit:
return super().fit(X, y)

return self

@abstractmethod
def predict( # noqa: D102
self: SelfType,
X: Input,
) -> TargetPrediction:
pass
return super().predict(X)

def score( # noqa: D102
self,
Expand All @@ -148,8 +164,8 @@ def score( # noqa: D102

class ClusterMixin( # noqa: D101
ABC,
Generic[Input],
sklearn.base.ClusterMixin, # type: ignore[misc]
Generic[Input],
):
def fit_predict( # noqa: D102
self,
Expand All @@ -161,22 +177,26 @@ def fit_predict( # noqa: D102

class RegressorMixin( # noqa: D101
ABC,
Generic[Input, TargetPrediction],
sklearn.base.RegressorMixin, # type: ignore[misc]
Generic[Input, TargetPrediction],
):
def fit( # noqa: D102
self: SelfType,
X: Input,
y: TargetPrediction,
) -> SelfType:
fit = getattr(super(), "fit", None)
if fit:
return super().fit(X, y)

return self

@abstractmethod
def predict( # noqa: D102
self: SelfType,
X: Input,
) -> TargetPrediction:
pass
return super().predict(X)

def score( # noqa: D102
self,
Expand Down
4 changes: 2 additions & 2 deletions skfda/ml/classification/_neighbors_classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@


class KNeighborsClassifier(
KNeighborsMixin[Input, NDArrayInt],
NeighborsClassifierMixin[Input, NDArrayInt],
KNeighborsMixin[Input, NDArrayInt],
):
"""
Classifier implementing the k-nearest neighbors vote.
Expand Down Expand Up @@ -190,8 +190,8 @@ def _init_estimator(self) -> _KNeighborsClassifier:


class RadiusNeighborsClassifier(
RadiusNeighborsMixin[Input, NDArrayInt],
NeighborsClassifierMixin[Input, NDArrayInt],
RadiusNeighborsMixin[Input, NDArrayInt],
):
"""
Classifier implementing a vote among neighbors within a given radius.
Expand Down

0 comments on commit dd373cd

Please sign in to comment.