ADABoosting with Skorch #601
Hi, you don't give much detail, so I'm not totally sure how you're trying to pass sample weights. But please have a look at our docs to see if this solves your problem. If it does, feel free to close the issue.
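Roughly, the documented pattern is to pass the input as a dict that also carries the sample weights and to override `get_loss` (this is just a sketch from memory; `WeightedNet` and `MyModule` are placeholder names, and the exact criterion argument may differ from the docs):

```python
from skorch import NeuralNetClassifier
from skorch.utils import to_tensor


class WeightedNet(NeuralNetClassifier):
    def get_loss(self, y_pred, y_true, X, *args, **kwargs):
        # per-sample losses, so the criterion must not reduce them
        loss_unreduced = super().get_loss(y_pred, y_true, X, *args, **kwargs)
        sample_weight = to_tensor(X['sample_weight'], device=self.device)
        return (sample_weight * loss_unreduced).mean()


# the module's forward should accept the extra dict key, e.g. via **kwargs
# net = WeightedNet(MyModule, criterion__reduction='none')
# net.fit({'X': X_train, 'sample_weight': weights}, y_train)
```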
@QuantumChamploo Any updates on this?
Hey, thanks for checking in, I have been distracted by the world collapsing. Thank you for pointing me to the docs, but I could not get the example to work. Are there examples of this being used? The docs were a bit short on implementation details, and I could not find a more verbose example. As for AdaBoosting, I'm a little more familiar with ... and I am not sure if the dictionary implementation will work with AdaBoost. Has AdaBoosting with skorch been implemented before?
Do you maybe have a code example, ideally something that I could reproduce 100%? It's a bit difficult for me to help you in the abstract. I'm fairly certain this can be made to work with skorch. Stay safe!
```python
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F
from skorch import NeuralNetClassifier
from sklearn.ensemble import AdaBoostClassifier

mnist = fetch_openml('mnist_784', cache=False)
X = mnist.data.astype('float32')
y = mnist.target.astype('int64')
X /= 255.0

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

mnist_dim = X.shape[1]
hidden_dim = int(mnist_dim / 8)
output_dim = len(np.unique(mnist.target))

X = {'data': X_train}
X['sample_weight'] = np.ones(len(X_train), dtype=float)


class ClassifierModule(nn.Module):
    def __init__(
            self,
            input_dim=mnist_dim,
            hidden_dim=hidden_dim,
            output_dim=output_dim,
            dropout=0.5,
            criterion__reduce=False,
    ):
        super(ClassifierModule, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.hidden = nn.Linear(input_dim, hidden_dim)
        self.output = nn.Linear(hidden_dim, output_dim)

    def forward(self, Xdict, **kwargs):
        X = Xdict['data']
        X = F.relu(self.hidden(X))
        X = self.dropout(X)
        X = F.softmax(self.output(X), dim=-1)
        return X

    def get_loss(self, y_pred, y_true, X, *args, **kwargs):
        # override get_loss to use the sample_weight from X
        loss_unreduced = super().get_loss(y_pred, y_true, X, *args, **kwargs)
        sample_weight = X['sample_weight']
        loss_reduced = (sample_weight * loss_unreduced).mean()
        return loss_reduced


torch.manual_seed(0)

net = NeuralNetClassifier(
    ClassifierModule,
    max_epochs=20,
    lr=0.1,
    device=device,
)

abc = AdaBoostClassifier(
    NeuralNetClassifier(
        ClassifierModule,
        max_epochs=20,
        lr=0.1,
        device=device,
    ),
    n_estimators=7,
    learning_rate=.9,
)
abc.fit(X_train, y_train)
```
This is essentially what I would like to get to work. I know that this is not the right way to implement sample_weights, as I could not make a simple example work, but I thought it would show what I am trying to do. Making AdaBoost work with skorch would be amazing, and it seems like it could be done.
I got something to work. Whether it's really correct, I'm not sure, since that would require digging into the sklearn implementation of AdaBoost. But at least for me, the model trains and gives reasonable results. Below are just the changed lines of code from your example, the rest is the same:

```python
from skorch.dataset import get_len
from skorch.utils import to_device
import warnings


class ClassifierModule(nn.Module):
    def __init__(
            self,
            input_dim=mnist_dim,
            hidden_dim=hidden_dim,
            output_dim=output_dim,
            dropout=0.5,
    ):
        super(ClassifierModule, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.hidden = nn.Linear(input_dim, hidden_dim)
        self.output = nn.Linear(hidden_dim, output_dim)

    def forward(self, X, **kwargs):
        X = F.relu(self.hidden(X))
        X = self.dropout(X)
        X = F.softmax(self.output(X), dim=-1)
        return X


class AdaboostMixin:
    """Mixin class to make the neural net compatible with sklearn's AdaBoost"""
    def fit(self, X, y=None, sample_weight=None, **fit_params):
        if getattr(self, 'criterion__reduction', True) != 'none':
            raise ValueError("Please pass argument criterion__reduction='none'")
        if sample_weight is None:
            raise ValueError("Must pass sample weights")

        m, n = len(X), len(sample_weight)
        if m != n:
            raise ValueError("X and sample_weight have different lengths: {} vs {}".format(m, n))

        if isinstance(X, dict) and ('sample_weight' in X):
            raise ValueError("X already contains a key called sample_weight")

        if not isinstance(X, dict):
            warnings.warn("Converting input data to dict")
            X = {'X': X}
        else:
            X = X.copy()
        X['sample_weight'] = sample_weight
        return super().fit(X, y, **fit_params)

    def get_loss(self, y_pred, y_true, X, *args, **kwargs):
        # override get_loss to use the sample_weight from X
        loss_unreduced = super().get_loss(y_pred, y_true, X, *args, **kwargs)
        sample_weight = X['sample_weight']
        sample_weight = to_device(sample_weight, self.device)
        loss_reduced = (sample_weight * loss_unreduced).mean()
        return loss_reduced


class MyNeuralNetClassifier(AdaboostMixin, NeuralNetClassifier):
    pass


abc = AdaBoostClassifier(
    MyNeuralNetClassifier(
        ClassifierModule,
        max_epochs=5,
        lr=0.1,  # <- probably a much bigger learning rate is required
        device=device,
        criterion__reduction='none',
        # train_split=False,  # <- consider turning off the train split
    ),
    n_estimators=5,
    learning_rate=.9,
)
```

With those hyper-parameters, each of the 5 nets gets a valid accuracy of about 0.5 and the ensemble achieves 0.79. The main change that was necessary is to add the sample weights to the input dict in `fit` and use them in `get_loss`.

If this works for you as well (or you still have trouble), please report back. Also report back if you have further insights into whether this works correctly. If this works as expected, we could consider adding something like this mixin to skorch.
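To be explicit, training and evaluating the ensemble then works the same way as in your original snippet (just a sketch; the exact scores will vary from run to run):

```python
# AdaBoost calls MyNeuralNetClassifier.fit(X, y, sample_weight=...) internally;
# the mixin converts X to a dict and injects those weights for get_loss
abc.fit(X_train, y_train)
print("test accuracy:", abc.score(X_test, y_test))
```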
Wow, thank you so much. This is a big breakthrough for me and is deeply appreciated; it made my week, really. I will run this by my PI and see what tests we need to do to make sure it is working the way we think it is, but it seems great. I'll get back to you soon about it. I got slightly different base accuracies for the individual nets. Were they all at .5 accuracy, or just around there? Thanks again.
I really hope this works out for you, good luck. There was some variance between nets: 0.5102, 0.5064, 0.6207, 0.5873, 0.5390 (valid accuracy). For serious training, I would turn off skorch-internal validation though (train_split=False).
Yeah, this is a thing. The mixin does more than "enable adaboost": it makes it easy to enable sample weights in general.

Quick thought: isn't `loss_unreduced` here actually already reduced?

```python
def get_loss(self, y_pred, y_true, X, *args, **kwargs):
    # override get_loss to use the sample_weight from X
    loss_unreduced = super().get_loss(y_pred, y_true, X, *args, **kwargs)
```

(I didn't know there were users using AdaBoost with neural networks. Sounds fun tho :D)
Yes, you're right.
Sorry, I didn't get the point.
As in:

```python
criterion = nn.NLLLoss()
y_pred = torch.Tensor(
    [[0.3, 0.1, 0.6],
     [0.4, 0.5, 0.1],
     [0.2, 0.2, 0.6]])
y_true = torch.tensor([1, 0, 2])
criterion(y_pred, y_true)
# tensor(-0.3667)
```

Let's say the reduction is left at its default; then this already returns a single scalar.
It shouldn't be, because we have `criterion__reduction='none'`.
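For illustration (my own toy example, reusing the numbers from above): with `reduction='none'` the criterion returns one loss per sample instead of a single mean, which is what the per-sample weighting relies on.

```python
import torch
from torch import nn

criterion = nn.NLLLoss(reduction='none')
y_pred = torch.tensor(
    [[0.3, 0.1, 0.6],
     [0.4, 0.5, 0.1],
     [0.2, 0.2, 0.6]])
y_true = torch.tensor([1, 0, 2])
print(criterion(y_pred, y_true))
# tensor([-0.1000, -0.4000, -0.6000])  -- per-sample losses, not a single mean
```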
I agree with @BenjaminBossan -- I think the implementation should be fine due to the lines:

```python
loss_reduced = (sample_weight * loss_unreduced).mean()
return loss_reduced
```

where the unreduced per-sample losses are weighted before taking the mean.

Regarding whether this works correctly, there is one valid concern though that I have. So, when I see this correctly -- I am not that familiar with skorch -- it automatically divides the dataset into minibatches for stochastic gradient descent. I am not sure if that's ideal for AdaBoost, because the sample weights refer to the entire training set, not to individual minibatches. To make it work, I think the code needs to be modified such that each minibatch sees the sample weights that belong to its samples.
Correct, though you could set the batch size to be the size of the dataset.
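As a sketch (reusing the classes from above; `batch_size` is skorch's regular argument, and the value shown is just the full training set size):

```python
full_batch_net = MyNeuralNetClassifier(
    ClassifierModule,
    max_epochs=5,
    lr=0.1,
    device=device,
    criterion__reduction='none',
    batch_size=len(X_train),  # one batch per epoch = the whole training set
)
```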
This point I don't get. The sample weights will be batched together with the training data (in fact, they are treated the same as training data for batching purposes), so if a minibatch starts with a sample from position 1234, the first sample weight will also be the one from position 1234. Or do you mean something else?

Btw., I edited my snippet above and removed one of the lines.
That's good, this would address my concern then :).
Great. Then let's wait for @QuantumChamploo to validate the results.
@QuantumChamploo did you have the time to try this yet?
@BenjaminBossan yes, I have been testing this implementation. Unfortunately I ran into a very strange error when trying to use it for analysis, where I was more rigorously testing the ensemble's effectiveness. While trying to create a concise reproduction of the error, I figured out a workaround, but I would still like to share the error.

What I am trying to do is sweep through the learning rate at 3 different values (.1, .01, .001), 3 times each. The only difference between the two following code snippets is that I flipped the ordering of the two loops while taking these accuracy values.

[two screenshots of the code and the resulting accuracy arrays, not reproduced here]

So in the first screenshot each array represents a cycle through the parameters, while in the second each array is just one value of the hyperparameter. In the second snippet, the trend is not what we expected or have seen when doing individual runs. The values seem much more random, while the first snippet has the clear trend that we expect. This was very strange, and was a big roadblock. It seems I found a workaround, but any clue why this would happen? Some sort of issue with sharing weights from instantiation to instantiation is my guess.
Thanks for reporting back @QuantumChamploo. So if I understand correctly, the values for the second ordering are not what you would expect. We had a recent fix (#617) that looks like it could be related. Could you try the current master branch of skorch and see if that fixes your issue?
@BenjaminBossan sorry if this is a silly question, but to get the current master branch of skorch, I would not want to reinstall with pip, but rather pull the master branch into my local Python environment?
I don't know exactly how you set up your local environment. Typically, it should work if you clone the repository and install from that local copy. However, if you don't want to clone the github repo, pip also has an option to install directly from github, explained in this SO answer. So I believe that route should also get you the current master.
Just to make sure that the correct version was installed, could you execute a quick version check in the cell above your results, so we know it's indeed the expected version? Jupyter Notebook/Lab is sometimes a bit weird with handling / respecting different virtual environments, as I found out in recent debugging projects ...
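Something like this (my guess at the check that was meant; the original snippet was not preserved):

```python
import skorch
print(skorch.__version__)  # print the installed skorch version
```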
Though the version doesn't tell you if the code is from the current master or not.
@QuantumChamploo if you could make it possible to replicate your problem, I would gladly take a look at it.
Oh, you are right. I just saw that the latest version on PyPI is the same version as in the master branch. Usually, recommended practice in Python's open source sci-stack is to bump the version number in master for the dev version, i.e. 0.8.1dev or 0.9.0a0. Both PyTorch and scikit-learn do this. For ease of debugging etc., I wonder what you think about adopting it here as well?
You are correct, it would be nice if we bumped the version after a release. We rarely do patch-level releases, but probably 0.8.1dev would be fine.
Sure. It's often also only used as a placeholder, i.e. you use 0.8.1dev during development and then later swap it with 0.8.1, 0.9.0, or 1.0.0 -- whatever is most appropriate.
Below are my two pieces of code with the fix applied:

[the two code snippets are not preserved here]
@QuantumChamploo Thanks for the code, I tested it on the current master branch and this is what I got:

```python
>>> # first example
>>> x_dists
[[0.19582857142857144, 0.7357714285714285, 0.868],
 [0.1944, 0.7064, 0.8588571428571429],
 [0.11297142857142857, 0.8406857142857143, 0.8203428571428572]]

>>> # second example
>>> x_dists
[[0.18571428571428572, 0.23194285714285715, 0.18457142857142858],
 [0.8293142857142857, 0.8109142857142857, 0.8523428571428572],
 [0.854, 0.8609714285714286, 0.8674285714285714]]

>>> # first example
>>> y_dist
[[0.3450857142857143, 0.8408, 0.8953142857142857],
 [0.3110857142857143, 0.8125142857142857, 0.8893714285714286],
 [0.24577142857142859, 0.8480571428571428, 0.8936571428571428]]

>>> # second example
>>> y_dist
[[0.42028571428571426, 0.4465142857142857, 0.26794285714285715],
 [0.8168, 0.8121714285714285, 0.8471428571428572],
 [0.8905714285714286, 0.8922285714285715, 0.8905714285714286]]
```

So I believe this looks correct, right? The second example is roughly the same as the first, only transposed. There is some randomness in the scores, but that's probably within the expected range. Assuming this is indeed correct, the bugfix mentioned above is probably the reason why it works now, and you need to make sure to really work on the current master (with the version bump, it should now be easier to verify: it should be 0.8.1dev).
@BenjaminBossan yes, this looks great! I'll bump to the current master and give this a try.

It seems the AdaboostMixin class is getting tripped up on the GPU device. The following is code that should reproduce the issue:

[reproduction code not preserved here]

Fingers crossed.
Without running your sample code, my guess is that the sample weights are not being moved to the GPU. Could you try converting them inside get_loss:

```python
from skorch.utils import to_device

...
# inside get_loss
sample_weight = X['sample_weight']
sample_weight = to_device(sample_weight, self.device)
...
```
That seemed to work, and it fixed the other issues as well. I wanted to note the arguments I passed to the neural net object: I initially had the shuffle flags on, but this ruined the net, i.e. it made accuracy drop to ~10 percent, which is just random guessing. The neural net was being trained right -- in the verbose output you could see it was training properly -- but the reported test accuracy was diminished. My guess is something got shuffled improperly. It's not a big deal, because I got this working, but I thought I would mention it.
Yes, when you shuffle the validation data, your predictions will be shuffled, and thus are no longer aligned with your targets, which, as you observed, results in random-level scores.
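In code, that means shuffling only the training iterator (a sketch reusing the classes from above; iterator_train__shuffle / iterator_valid__shuffle are passed through to the respective DataLoaders):

```python
net = MyNeuralNetClassifier(
    ClassifierModule,
    criterion__reduction='none',
    device=device,
    iterator_train__shuffle=True,   # shuffling training batches is fine
    iterator_valid__shuffle=False,  # keep validation/prediction order aligned with targets
)
```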
I think this issue has been solved. We could consider adding something like the AdaboostMixin to skorch.
Hi everyone, this was very useful for me and I tried to apply the same concept to GradientBoosting, but I get an error:

```python
base_estimator = NeuralNetClassifier(ClassifierModule, ...)  # other arguments omitted here
abc = GradientBoostingClassifier(base_estimator, n_estimators=7, learning_rate=.9)
abc.fit(X_train, y_train)
```

The traceback ends in /opt/conda/lib/python3.9/site-packages/sklearn/ensemble/_gb.py in _check_params(self).

Does anyone have an idea on how to make NeuralNetClassifier compatible? I don't understand why the loss parameter of GradientBoosting is not compatible.
I would like to use skorch to apply sklearn's AdaBoost methods to torch-based models. When implemented naively, I get an error saying "NeuralNetClassifier doesn't support sample_weight". I kind of expected this, but is there a way to extend the models to support sample weights?
I apologize if this is an inappropriate question or if it has been answered before, but I could not find anyone else trying this.