Estimator wrapper #18

Merged: 14 commits merged into main on May 16, 2024
Conversation

@aazuspan (Contributor) commented on May 6, 2024

@grovduck, I've been staring at this too long and feel like I'm getting to the point where there's probably some unnecessary complexity left over, but I'm going to start making it worse if I keep working on it! A second set of eyes would be a huge help when you have a chance!

Closes #13 and #14 by switching from a functional API for sklearn methods (currently predict and kneighbors) to a wrapper estimator. Now instead of calling those functions directly, you can wrap an existing estimator, fit it, and call the estimator's methods with image data, e.g.

```python
from sknnr import GNNRegressor

from sknnr_spatial.datasets import load_swo_ecoplot
from sknnr_spatial import wrap

X_img, X, y = load_swo_ecoplot(as_dataset=True)
est = wrap(GNNRegressor(n_neighbors=7)).fit(X, y)

est.predict(X_img).PSME_COV.plot()
```

Because estimators must be fit after wrapping, we can now force consistent outputs with single-output data (#14) by 1) always squeezing single-output y to 1D, since that produces consistent results across all estimators, and 2) storing the number and names of the targets used for fitting. This avoids the need to pass y data during prediction or to try to infer things from the estimator.
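For illustration, a minimal sketch of how that squeeze-and-record step could look at fit time (the helper name and return values here are hypothetical, not the actual implementation):

```python
import numpy as np

def _process_targets(y):
    """Hypothetical helper: squeeze single-output y to 1D and record target metadata."""
    # Grab column names if y is a DataFrame-like object
    names = list(y.columns) if hasattr(y, "columns") else None
    y = np.asarray(y)
    if y.ndim == 2 and y.shape[1] == 1:
        # Single-output: squeeze to 1D so every estimator behaves consistently
        y = y.ravel()
    n_targets = 1 if y.ndim == 1 else y.shape[1]
    return y, n_targets, names
```

Storing the target count and names on the wrapper at fit time is what lets predict rebuild named outputs later without needing y again.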

Some notes below:

Naming

The wrap function name is 100% a placeholder. Since this function is pretty much the entire public API, it seems like we should give it a good name!

We could also skip the wrapper function and have users instantiate the ImageEstimator directly (also open to naming suggestions there), but I thought that might feel more like you're getting a new estimator type with different methods rather than just a wrapper around your estimator.

Non-image data

I added a check_is_x_image decorator that falls back to the wrapped sklearn methods for non-image data, so e.g. you can still predict from tabular data with a wrapped estimator. I'm not totally sure that's the right decision, so let me know if you have any thoughts on that.
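Roughly, the fallback behaves like this (a simplified sketch; the `_wrapped` attribute name is just illustrative):

```python
import functools
import xarray as xr

def check_is_x_image(func):
    """If X isn't an image, delegate to the wrapped estimator's method of the same name."""
    @functools.wraps(func)
    def wrapper(self, X, *args, **kwargs):
        if isinstance(X, (xr.DataArray, xr.Dataset)):
            return func(self, X, *args, **kwargs)
        # Tabular data: fall back to the underlying sklearn estimator
        return getattr(self._wrapped, func.__name__)(X, *args, **kwargs)
    return wrapper
```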

Image wrappers instead of single dispatch

The single dispatch system for dispatching sklearn methods to the correct image types was causing some unnecessary duplication in the Dask-backed images that was only going to get worse as we add more sklearn methods (like transform). I switched that for an "image wrapper" system that encapsulates the preprocessor and sklearn methods for each image type. I thought about just combining that directly with the preprocessor classes, but I figured they serve somewhat different purposes so it might be better to keep them separate.

Once again, there's probably a better name than ImageWrapper for that class. I feel like I'm currently using "wrapper" in a lot of different ways that might be making things unnecessarily confusing, e.g. the ImageWrapper that contains additional functionality for an image, the EstimatorWrapper that allows access to the underlying estimator, and the wrap_image and unwrap_image in the test system that convert between types.
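To make the idea concrete, here is a stripped-down sketch of the wrapper-per-image-type pattern (class and method names are illustrative, and NaN masking, dask handling, etc. are omitted):

```python
from abc import ABC, abstractmethod

import xarray as xr

class ImageWrapper(ABC):
    """Encapsulates preprocessing and sklearn methods for one image type."""

    def __init__(self, image):
        self.image = image

    @abstractmethod
    def predict(self, estimator):
        """Apply a fitted estimator over this image type."""

class DataArrayWrapper(ImageWrapper):
    """Handles xr.DataArray images with (variable, y, x) dims."""

    def predict(self, estimator):
        # Flatten pixels into a (n_pixels, n_features) table, predict, then
        # restore the spatial shape.
        flat = self.image.stack(pixel=("y", "x")).transpose("pixel", ...)
        pred = estimator.predict(flat.values)
        return xr.DataArray(
            pred.reshape(self.image.sizes["y"], self.image.sizes["x"], -1),
            dims=("y", "x", "target"),
        )
```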

Incompatible estimator methods

I tried a handful of different approaches to overriding just the relevant methods for a wrapped estimator (e.g. don't implement kneighbors for a RandomForestRegressor) and never came up with a perfect solution. Programmatically assigning methods with setattr looked promising, but I never found a way to get that working with IntelliSense, which wasn't able to trace which methods got assigned where. I also tried implementing all methods and removing unused ones, but it doesn't seem possible to dynamically remove a method from one instance of a class without removing it from the class entirely.

The solution I settled on was the check_wrapper_implements decorator that dynamically throws a NotImplementedError if you try to call an unsupported method. Unfortunately this means unsupported methods are still visible and documented, but that seemed preferable to the alternative of not documenting supported methods.
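The guard is conceptually just this (simplified sketch; `_wrapped` is again an assumed attribute name):

```python
import functools

def check_wrapper_implements(func):
    """Raise NotImplementedError if the wrapped estimator lacks the corresponding method."""
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        if not hasattr(self._wrapped, func.__name__):
            msg = f"{type(self._wrapped).__name__} does not implement {func.__name__}."
            raise NotImplementedError(msg)
        return func(self, *args, **kwargs)
    return wrapper
```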

Feature name warnings

No changes here - warnings are still present when fit with feature names.

aazuspan added 13 commits May 2, 2024 17:42
The single dispatching system became a little convoluted with the
introduction of the estimator wrapper, and required duplication of
things like the Dataset and DataArray methods. I switched to an
inheritance system where each image type has an ImageWrapper subclass
that 1) defines the preprocessor for the image type and 2) implements
the sklearn methods.
aazuspan added the "enhancement" label on May 6, 2024
aazuspan requested a review from grovduck on May 6, 2024 at 18:47
aazuspan self-assigned this on May 6, 2024
aazuspan linked an issue on May 6, 2024 that may be closed by this pull request
aazuspan changed the title from "Wrapper" to "Estimator wrapper" on May 6, 2024
@grovduck (Member) commented on May 7, 2024

@aazuspan, I haven't yet dug into the code, but I wanted to expand on running the sample code from yesterday. I gave you a bit of a red herring yesterday in saying that the issue was running the code in VSCode vs. JupyterLab.

When I had run it in JupyterLab, it was using a LocalCluster that I had been using in a different notebook. After I shut down that cluster and ran the sample code again in JupyterLab, I am hanging on est.predict(X_img).PSME_COV.plot(). This would suggest that my thread-based dask configuration (i.e. without a LocalCluster) is not behaving correctly.

Can you detail how/where you are running the sample code and perhaps relevant package versions? It's very possible that it's something in my setup only, but want to make sure this wouldn't trip up others. I'll continue to experiment as well. Also, if you have timings for thread-based vs local cluster, I think that might be helpful. I'm getting 2.22 s for the local cluster with 36 cores, 128GB RAM.

@aazuspan (Contributor, Author) commented on May 7, 2024

Interesting, seems like a thread-based issue might tie into the synthetic plot performance we were talking about as well...

Here's a slightly more repeatable test setup that I'm running in a Jupyter notebook in VS Code:

```python
import dask
from sknnr import GNNRegressor

from sknnr_spatial import wrap
from sknnr_spatial.datasets import load_swo_ecoplot

dask.config.set(scheduler="threads")

X_img, X, y = load_swo_ecoplot(as_dataset=True)
est = wrap(GNNRegressor(n_neighbors=7)).fit(X, y)
pred = est.predict(X_img)

%timeit pred.compute()
```

Result:

166 ms ± 1.84 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

And using a default Client() as the scheduler instead:

1.28 s ± 61.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

My kernel is the hatch default environment with:

python: 3.10.12
dask: 2024.3.1
sknnr: 0.1.0-alpha
xarray: 2024.2.0

My PC has 64 GB of RAM, and the local cluster has 4 workers with 2 threads each.

@grovduck (Member) commented on May 7, 2024

Thanks, @aazuspan. Part (or all) of the issue is my stupidity or lack of awareness. I definitely was not specifying:

dask.config.set(scheduler="threads")

When I do this, I'm getting similar timings to you:

  • dask.config.set(scheduler="threads"): 242 ms ± 13.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
  • Client(): 955 ms ± 450 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

But what I still don't totally understand is why that setting matters, since the threads scheduler is supposed to be the default without specifying anything else. From the dask deployment docs:

Local Machine
You can run Dask without any setup. Dask will use threads on your local machine by default.

```python
import dask.dataframe as dd
df = dd.read_csv(...)
df.x.sum().compute()  # This uses threads on your local machine
```

Is this not the case?
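For reference, one quick way to check which scheduler a Dask collection will actually use for .compute() (a small diagnostic sketch using dask's get_scheduler helper):

```python
import dask
import dask.array as da
from dask.base import get_scheduler

x = da.ones((1000, 1000), chunks=100)

# With no configuration and no active distributed Client, dask arrays fall
# back to the threaded scheduler; creating a Client() changes that default.
print(dask.config.get("scheduler", default=None))  # None unless set explicitly
print(get_scheduler(collections=[x]))              # e.g. dask.threaded.get
```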

EDIT: I think I have something weird going on with my main virtualenv (not the hatch env) that I've been testing with. In that specific virtualenv, the test code is not finishing (in the screenshot below, it has been running for 42 s and is still going). This is definitely an issue on my end, and I'll continue to try to sort it out. It works when using the hatch environment associated with sknnr-spatial.

We are tracking this issue in #19 now.

[screenshot: the test cell still running after 42 s]

@grovduck (Member) left a comment

@aazuspan, awesome work yet again. I really like the estimator wrapper API and can see how it's going to make other methods (e.g. transform) much easier to deal with. Code-wise, I'm not nearly clever enough to suggest too many changes - most of my edits are documentation or type-based. Pending those, I think it looks really good!

(Just "commenting" for now, but happy to change to an approval.)

Review threads (all resolved):
  • src/sknnr_spatial/utils/estimator.py
  • src/sknnr_spatial/image/_base.py
  • src/sknnr_spatial/image/_base.py (outdated)
  • src/sknnr_spatial/estimator.py
  • src/sknnr_spatial/utils/estimator.py (outdated)
  • src/sknnr_spatial/estimator.py
  • src/sknnr_spatial/utils/image.py
Better naming and type annotations, remove redundant fitted checks,
test for expected NotFittedErrors.
@aazuspan (Contributor, Author) commented
Thanks for the review @grovduck! I think everything's resolved, but if you have a chance to take a final look before I merge that would be appreciated.

@grovduck (Member) commented
All looks great, @aazuspan. I noticed that you added a new test and caught a variable naming issue as well (x_image to X_image). LGTM!

@aazuspan aazuspan merged commit fc1bdbf into main May 16, 2024
5 checks passed
@aazuspan aazuspan deleted the wrapper branch May 16, 2024 01:13
Labels: enhancement (New feature or request)

Linked issues that may be closed by this pull request: Handle single-output estimators, Implement estimator wrapper

2 participants