
Support spatial transform in estimators that inherit from TransformerMixin #16

Open
grovduck opened this issue May 1, 2024 · 6 comments
Labels
enhancement New feature or request

Comments

@grovduck
Member

grovduck commented May 1, 2024

For some use cases, it would be valuable to return a spatial representation of the output of an estimator's transform method when that estimator inherits from sklearn's TransformerMixin (for example, sklearn's PCA or sknnr's CCATransformer).

As of #12, sknnr-spatial supports predict and kneighbors in a functional context (soon to be implemented as an estimator wrapper in #13). Supporting transform will likely be an extension of this logic. Based on some initial crude experimentation, the following code implements transform using a functional API. This code is not fully tested, but I'm introducing it here to keep a record of what was done.

src/sknnr-spatial/image/_base.py

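from __future__ import annotations

from functools import singledispatch

import xarray as xr
from numpy.typing import NDArray
from sklearn.base import BaseEstimator


@singledispatch
def transform(
    X_image: NDArray | xr.DataArray | xr.Dataset,
    *,
    estimator: BaseEstimator,
    nodata_vals=None,
) -> None:
    """Apply a fitted transformer to an image; dispatched on the image type."""
    msg = f"transform is not implemented for type `{X_image.__class__.__name__}`."
    raise NotImplementedError(msg)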

src/sknnr-spatial/image/ndarray.py

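import numpy as np
from numpy.typing import NDArray
from sklearn.base import BaseEstimator
from sklearn.utils.validation import check_is_fitted

# `transform` (the singledispatch base) and `NDArrayPreprocessor` are
# defined elsewhere in the package.


@transform.register(np.ndarray)
def _transform_from_ndarray(
    X_image: NDArray,
    *,
    estimator: BaseEstimator,
    nodata_vals=None,
) -> NDArray:
    check_is_fitted(estimator)
    preprocessor = NDArrayPreprocessor(X_image, nodata_vals=nodata_vals)

    # TODO: Deal with sklearn warning about missing feature names
    y_transform_flat = estimator.transform(preprocessor.flat)

    return preprocessor.unflatten(y_transform_flat, apply_mask=True)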

src/sknnr-spatial/image/_dask_backed.py

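import dask.array as da
from sklearn.base import BaseEstimator
from sklearn.utils.validation import check_is_fitted

# DaskBackedType and the preprocessor classes are defined elsewhere in the package.


def transform_from_dask_backed_array(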
    X_image: DaskBackedType,
    *,
    estimator: BaseEstimator,
    y=None,
    preprocessor_cls: type[DataArrayPreprocessor] | type[DatasetPreprocessor],
    nodata_vals=None,
) -> DaskBackedType:
    """Generic transform wrapper for Dask-backed arrays."""
    check_is_fitted(estimator)
    preprocessor = preprocessor_cls(X_image, nodata_vals=nodata_vals)

    # HACK: Using get_feature_names_out() to infer the number of targets.
    # I don't think this is guaranteed to work for all transformers.
    target_names = estimator.get_feature_names_out()
    n_targets = len(target_names)

    y_transform = da.apply_gufunc(
        estimator.transform,
        "(x)->(y)",
        preprocessor.flat,
        axis=preprocessor.flat_band_dim,
        output_dtypes=[float],
        output_sizes={"y": n_targets},
        allow_rechunk=True,
    )

    return preprocessor.unflatten(y_transform, var_names=target_names)

src/sknnr-spatial/image/dataarray.py

import xarray as xr
from sklearn.base import BaseEstimator


@transform.register(xr.DataArray)
def _transform_from_dataarray(
    X_image: xr.DataArray, *, estimator: BaseEstimator, y=None, nodata_vals=None
) -> xr.DataArray:
    return transform_from_dask_backed_array(
        X_image,
        estimator=estimator,
        y=y,
        nodata_vals=nodata_vals,
        preprocessor_cls=DataArrayPreprocessor,
    )

src/sknnr-spatial/image/dataset.py

import xarray as xr
from sklearn.base import BaseEstimator


@transform.register(xr.Dataset)
def _transform_from_dataset(
    X_image: xr.Dataset, *, estimator: BaseEstimator, y=None, nodata_vals=None
) -> xr.Dataset:
    return transform_from_dask_backed_array(
        X_image,
        estimator=estimator,
        y=y,
        nodata_vals=nodata_vals,
        preprocessor_cls=DatasetPreprocessor,
    )

grovduck added the enhancement (New feature or request) label on May 1, 2024
@grovduck
Member Author

grovduck commented May 1, 2024

@aazuspan, please feel free to modify the issue description. I just wanted to capture what I had done to make synthetic kNN work with the transform step. I'll stash these changes locally so I can test the latest changes you've made in #12.

@aazuspan
Contributor

aazuspan commented May 1, 2024

Looks great @grovduck, thanks!

I just did a quick check and it looks like nearly all transformers implement get_feature_names_out. Maybe we just throw a NotImplementedError for any that don't? For reference:

from sklearn.utils.discovery import all_estimators

for name, trans in all_estimators("transformer"):
    if not hasattr(trans, "get_feature_names_out"):
        print(name)

"""
FeatureHasher
HashingVectorizer
LabelBinarizer
LabelEncoder
MultiLabelBinarizer
PatchExtractor
"""

At a glance, none of these seem relevant to applying in 2D space.

@grovduck
Member Author

grovduck commented May 1, 2024

Maybe we just throw a NotImplementedError for any that don't?

That definitely seems like a reasonable approach to me; I can't see us needing any of those transformers. The flip side would be whether those estimators/transformers that do implement get_feature_names_out make sense to transform into 2D space? Looks like there are 84 of these (not counting sknnr transformers).

@aazuspan
Contributor

aazuspan commented May 3, 2024

The flip side would be whether those estimators/transformers that do implement get_feature_names_out make sense to transform into 2D space? Looks like there are 84 of these (not counting sknnr transformers).

Yeah, most of those 84 probably would never be used, but I suppose if they follow a consistent protocol and we can implement them all in one go, there's no harm (maybe with a disclaimer that we only explicitly test/support a subset of them?). One thing we'll need to watch out for is any estimator/transformer that modifies the spatial shape, i.e. that returns more or fewer samples than it was fit with, since that would break the Dask side. I'm not aware of anything like that, but I've probably never touched 90% of the functionality in sklearn, so I wouldn't be shocked.

I'm pretty excited for this feature - being able to run PCA or StandardScaler on images seamlessly will be great!

@grovduck
Member Author

grovduck commented May 6, 2024

One thing we'll need to watch out for is any estimator/transformer that modifies the spatial shape, i.e. that returns more or fewer samples than it was fit with, since that would break the Dask side. I'm not aware of anything like that, but I've probably never touched 90% of the functionality in sklearn, so I wouldn't be shocked.

Good call. I only recognized a handful of those transformers as well, so I don't know whether any of them would modify the shape, but it's definitely something to watch out for.

I'm pretty excited for this feature - being able to run PCA or StandardScaler on images seamlessly will be great!

Absolutely!

Just to be clear, your plan is to tackle #13 before handling this one, correct? Please let me know if there are "side jobs" on either of these issues that you'd like my help with (other than reviews).

@aazuspan
Contributor

aazuspan commented May 6, 2024

Just to be clear, your plan is to tackle #13 before handling this one, correct? Please let me know if there are "side jobs" on either of these issues that you'd like my help with (other than reviews).

Yes, I made some changes in #18 that should hopefully reduce some code duplication when adding new methods like this (although that means your implementation will need to be refactored a little bit to match, sorry!). If you want to tackle this issue once #18 is merged, that would be great! You've got a better idea of the real-world use cases for this, and you've already figured out the tricky part of getting the output shape.
