
Incremental learning on large chunks with multi-input Keras models is slow #764

Closed · 2over12 opened this issue Dec 1, 2020 · 19 comments


2over12 commented Dec 1, 2020

The recommended way to integrate Keras models into Dask-ML will be slow if the Keras model has multiple inputs. The way to handle multiple inputs with SciKeras is data transformers that are called before the model is invoked, as described here: https://github.com/adriangb/scikeras/blob/master/notebooks/DataTransformers.ipynb. For multiple inputs this involves splitting the chunk into N arrays, where N is the number of inputs. As a result, the fit task in Incremental includes a particularly slow data split that runs in serial. This slow fit causes chunks to be read far faster than the model can fit them, eventually spilling most of the array back to disk.

Ideally the split would occur in the non-serial steps instead of being delayed until the fit task. The fit function that is eventually called by Incremental, here https://github.com/dask/dask-ml/blob/master/dask_ml/_partial.py, is built for single arrays and has no injection point. If the user could inject a transformation step into the graph right before partial_fit is called, then each chunk could be split in parallel and be ready to go for the actual fitting step. Injecting a transformation step would allow the pipeline to prepare chunks so that the only part done in serial is the actual training.
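
For concreteness, a rough sketch of the kind of transformer involved, following the pattern in the linked DataTransformers notebook (the class and function names are illustrative, the model-building code is omitted, and exact API details may differ between SciKeras versions):

from sklearn.preprocessing import FunctionTransformer
from scikeras.wrappers import KerasClassifier

def split_columns(X):
    # array -> list of arrays: one entry per Keras input
    return [X[:, 0], X[:, 1]]

class MultiInputClassifier(KerasClassifier):
    @property
    def feature_encoder(self):
        # Runs on every chunk right before the underlying Keras fit/partial_fit,
        # i.e. inside the serial fitting step described above.
        return FunctionTransformer(split_columns)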

@stsievert (Member)

cc @adriangb


adriangb commented Dec 1, 2020

Thank you for the issue.

Let's see if I understand this correctly (please correct me @stsievert, I'm fairly certain I'm wrong).

Let's say we have some data, X, that corresponds to two inputs to a model:

X = np.random.rand(100, 2)

Dask splits the data element-wise into chunks to be trained in parallel. Let's say it splits it into 4 chunks of size 25:

blocks = np.split(X, 4, axis=0)
for block in blocks:
    tasks.schedule(model, "partial_fit", block)  # or whatever this looks like internally

Then within each partial_fit call (which it sounds like happens sequentially?) np.split(X, 2, axis=1) is called (as the "transformer").

Whether that is right or not, what I'm not understanding is how the transformer/column split slows things down. At that point, there is data being passed to the model for training. How is np.split(X, 2, axis=1) any slower than the actual model training?


2over12 commented Dec 1, 2020

Right, so this is essentially it. partial_fit happens sequentially, so the data transformers in SciKeras are called sequentially. In the background, Dask will continue to load batches from the array that will get fed to partial_fit as it is ready.

Our chunks are 4,000,000x2, so our data transformer of [arr[:, 0], arr[:, 1]] is going to take a while. It competes with training time since the model is fairly light. Ideally, instead of having the batches just sitting there after being read, data transformers could be performed before fitting, so that as Dask is preloading batches they are completely ready for training. Like I said, this is a narrow issue where you have a Keras model that needs a relatively expensive data transformation and big chunks (so a lot of memory to throw around).

Essentially, the idea is that since data transformers should not really need to be processed in serial, perhaps they can somehow be lifted out of the intentionally serial training step.
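
A rough sketch of what that could look like from user code, assuming a distributed client so that dask.persist runs the splits on the workers in the background (X and y are dask arrays, model is a SciKeras wrapper that accepts pre-split inputs; all names are placeholders):

import dask

@dask.delayed
def split_inputs(chunk):
    # the array -> list-of-arrays transform, done once per chunk
    return [chunk[:, 0], chunk[:, 1]]

# pre-split every chunk in parallel, ahead of the serial training loop
prepared = dask.persist(*[split_inputs(b) for b in X.to_delayed().ravel()])

for inputs, y_block in zip(prepared, y.to_delayed().ravel()):
    # the split has already happened on the workers; only training is serial here
    model.partial_fit(inputs.compute(), y_block.compute())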


adriangb commented Dec 1, 2020

I see. I find it interesting that an array split competes with the model training time, but I dunno.

It sounds to me like what you need would be something on the Dask side, right? SciKeras can be configured to accept data and not process it at all (might have to override _validate_data or something, but that'd be minor). Or is there something we could do in SciKeras to make things faster for you that I'm not seeing?

I guess another question is: did you run cProfile or something to make sure that the slow part is the data splitting? It might be some other part of SciKeras that is slow. I haven't benchmarked with data of anywhere near that size.


2over12 commented Dec 1, 2020

Yeah, so this would be on the Dask side if anything. I'm looking to inject a transformation into the graph created by fit, defined here: https://github.com/dask/dask-ml/blob/master/dask_ml/_partial.py. Since there is no way to execute user code between that graph creation (which can only be done on a single array) and the serial step, there is no way to perform the split in parallel.

Yeah, you are also correct that I may have jumped the gun on what the slow step is before training starts. Something hangs for about 50 seconds before batches actually get fed to the model on each chunk; I assumed it was the data transformer and not something in SciKeras. I figured the SciKeras wrapper is fairly thin, but I'll do some profiling to double-check.


adriangb commented Dec 1, 2020

Sounds good. Let me know if you find anything interesting, or if I can help with that profiling in any way. An easy way could be to inject a timestamp at the start of BaseWrapper.fit, a timestamp at the start and end of your transformer and maybe a timestamp right before BaseWrapper.model_.fit gets called.
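
For example, a minimal sketch of that kind of instrumentation without touching SciKeras internals: time the transformer itself and the overall call, and attribute the difference to wrapper overhead plus training (clf, X_chunk, and y_chunk are placeholders):

import time

class TimedSplitter:
    # sklearn-style transformer that logs how long the split takes
    def fit(self, X, y=None):
        return self

    def transform(self, X):
        start = time.perf_counter()
        out = [X[:, 0], X[:, 1]]
        print(f"split: {time.perf_counter() - start:.2f}s")
        return out

start = time.perf_counter()
clf.partial_fit(X_chunk, y_chunk)
print(f"partial_fit total: {time.perf_counter() - start:.2f}s")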


2over12 commented Dec 1, 2020

Following up on this, it can likely be closed. It is still an interesting idea if people do have more heavyweight transforms that require splitting into a non-array form, but this is not the cause of the slowness. The cause of the slowness is actually _validate_sample_weight in SciKeras. Passing a sample weight of all ones ([1, 1, ...]) instead of None to Keras fit causes a slow conversion step in TensorFlow (I'm not super sure why; I think weights may cause each sample to be handled separately). _validate_sample_weight transforms sample weights of None into an array of 1s.
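
A minimal illustration of the difference being described, using plain Keras (model is an already-compiled tf.keras.Model; how much slower the second call is at this scale is exactly the behaviour reported above, not something this snippet guarantees):

import numpy as np

model.fit(X, y, epochs=1, sample_weight=None)             # fast path
model.fit(X, y, epochs=1, sample_weight=np.ones(len(X)))  # builds an extra per-sample weight tensor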


adriangb commented Dec 1, 2020

BaseWrapper._validate_sample_weight should not be called if you pass sample_weight=None (or if you use class_weight). Are you using one of those two?


2over12 commented Dec 1, 2020

Both sample_weight and class_weight are None, but it still gets called from the KerasClassifier wrapper. Looking at the code, I don't see why it wouldn't get called from partial_fit regardless, since it gets called in _fit.


adriangb commented Dec 1, 2020

Is this the call you're talking about? It's protected against None:

https://github.com/adriangb/scikeras/blob/45da4ad378968e28cbd608d4af10d89fb65f7e21/scikeras/wrappers.py#L832-L833


2over12 commented Dec 1, 2020

Ah, got it. It looks like the last tagged release, 0.2.0, doesn't have that check.


adriangb commented Dec 1, 2020

Interesting! Well I'm glad we may have accidentally fixed this for you 😆. I'll probably push out 0.2.1 in the next couple of days. Can you test off of master and see if the slowdown is gone?

@stsievert (Member)

> ...if people do have more heavyweight transforms that require splitting into a non-array form, but this is not the cause of the slowness.

If people did have "heavyweight transforms", could that be addressed by doing the transformations beforehand and using BaseWrapper instead of KerasClassifier or KerasRegressor? By default, BaseWrapper doesn't transform the input (source: wrappers.py#L698).

If so, maybe that should be a documentation note for SciKeras.
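
Something along those lines might look like this (get_model and the pre-encoded y_encoded_chunk are placeholders; constructor parameter names may differ between SciKeras versions):

from scikeras.wrappers import BaseWrapper

# y_encoded_chunk is already one-hot encoded (or otherwise matched to the loss),
# so no SciKeras feature/target transformation is needed
est = BaseWrapper(model=get_model, loss="categorical_crossentropy")
est.partial_fit(X_chunk, y_encoded_chunk)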


adriangb commented Dec 1, 2020

I think, generally, there are two types of transformations:

  • Transformations from array -> array. This includes one-hot-encoding, label encoding, scaling, etc.
  • Transformations from array -> list, dict, etc. These are for multi input-output.

The former can be "heavy" processing, but since it's array -> array it should be able to be done in dask preprocessing, in parallel. The latter, I think in all cases, should be pretty light transformations.
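
As a rough illustration of the distinction (dask_ml's StandardScaler standing in for a heavy array -> array step; the column split standing in for the light array -> list step):

from dask_ml.preprocessing import StandardScaler

# array -> array: potentially heavy, but runs in parallel on the Dask side
X_scaled = StandardScaler().fit_transform(X)   # X is a dask array

# array -> list (multi-input): light, fine to leave inside the wrapper's transformer
def to_inputs(chunk):
    return [chunk[:, 0], chunk[:, 1]]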

If I understood correctly @stsievert, you are suggesting something like:

> If you are willing to do all of the pre-processing on your data (i.e. matching the encoding with the loss function, etc.), you can skip SciKeras pre-processing by overriding `target_encoder` and `feature_encoder`.

?


2over12 commented Dec 1, 2020

I agree with @adriangb: as long as the transformations are array->array, it can all be done in Dask and will work fine. The only issue here would be if there is some expensive transformation that can only be done as array->list or array->dict, which seems unlikely.

> Interesting! Well I'm glad we may have accidentally fixed this for you 😆. I'll probably push out 0.2.1 in the next couple of days. Can you test off of master and see if the slowdown is gone?

Yeah, master works fine.

@stsievert (Member)

> If I understood correctly @stsievert, you are suggesting something like: "If you are willing to do all of the pre-processing on your data (i.e. matching the encoding with the loss function, etc.), you can skip SciKeras pre-processing by overriding target_encoder and feature_encoder."

Almost. This is more what I'm thinking:

> If you are willing to do all of the pre-processing on your data (i.e. matching the encoding with the loss function, etc.), you can skip SciKeras pre-processing by using BaseWrapper instead of KerasRegressor and KerasClassifier. By default, BaseWrapper's transformers are minimal and pass the input through – the only overhead they add is the time required for a Python function call.

@stsievert (Member)

@2over12 it sounds like this issue is resolved #764 (comment). Is that correct, or is there anything else to do?

On Gitter, you mentioned a separate issue on memory usage of wrappers.Incremental. I'm curious to see an issue around that – I think that'd help improve Dask-ML. Thanks!


adriangb commented Dec 6, 2020

@2over12 I released SciKeras v0.2.1 in case you want to switch back to a PyPI/tagged release instead of master.


2over12 commented Dec 7, 2020

> @2over12 it sounds like this issue is resolved #764 (comment). Is that correct, or is there anything else to do?
>
> On Gitter, you mentioned a separate issue on memory usage of wrappers.Incremental. I'm curious to see an issue around that – I think that'd help improve Dask-ML. Thanks!

Yes, that comment and the new release have fixed this issue, so I am closing it for now. If somebody does end up needing expensive array->dictionary transforms, perhaps this can be explored later.

I have opened a new issue regarding the stuff discussed on Gitter here: #765 (comment).
