ADABoosting with Skorch #601

Closed
QuantumChamploo opened this issue Feb 27, 2020 · 37 comments

Comments

@QuantumChamploo

I would like to use skorch to implement Sklearn's ADABoosting methods on torch-based models. When I implement this naively, I get an error saying "NeuralNetClassifier doesn't support sample_weight". I kind of expected this, but is there a way to extend the models to support sample weights?

I apologize if this is an inappropriate question or if it has been answered, but I could not find anyone else trying this.

@BenjaminBossan
Collaborator

Hi, you don't give much detail, so I'm not totally sure how you're trying to pass sample weights. But please have a look at our docs to see if this solves your problem. If it does, feel free to close the issue.

@BenjaminBossan
Collaborator

@QuantumChamploo Any updates on this?

@QuantumChamploo
Author

Hey, thanks for checking in; I've been distracted by the world collapsing. Thank you for pointing me to the docs, but I could not get the example to work. Are there examples of this being used? The docs were a bit short on implementation details and I could not find a more verbose example.

As for ADABoosting, which I'm a little more familiar with, I am not sure if the dictionary implementation will work with ADA. Has ADABoosting with Skorch been implemented before?

@BenjaminBossan
Collaborator

Do you maybe have a code example, ideally something that I could reproduce 100%? It's a bit difficult for me to help you in the abstract. I'm fairly certain this can be made to work with skorch.

Stay safe!

@QuantumChamploo
Author

QuantumChamploo commented Mar 20, 2020

from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn
import torch.nn.functional as F

from skorch import NeuralNetClassifier

from sklearn.ensemble import AdaBoostClassifier


mnist = fetch_openml('mnist_784', cache=False)

X = mnist.data.astype('float32')
y = mnist.target.astype('int64')

X /= 255.0

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)

device = 'cuda' if torch.cuda.is_available() else 'cpu'


mnist_dim = X.shape[1]
hidden_dim = int(mnist_dim/8)
output_dim = len(np.unique(mnist.target))


X = {'data': X_train}
X['sample_weight'] = np.ones(len(X_train), dtype=float)

class ClassifierModule(nn.Module):
    def __init__(
            self,
            input_dim=mnist_dim,
            hidden_dim=hidden_dim,
            output_dim=output_dim,
            dropout=0.5,
            criterion__reduce=False

        ):
        super(ClassifierModule, self).__init__()
        self.dropout = nn.Dropout(dropout)

        self.hidden = nn.Linear(input_dim, hidden_dim)
        self.output = nn.Linear(hidden_dim, output_dim)


        
    def forward(self, Xdict, **kwargs):
        X = Xdict['data']
        X = F.relu(self.hidden(X))
        X = self.dropout(X)
        X = F.softmax(self.output(X), dim=-1)
        return X
    
    def get_loss(self, y_pred, y_true, X, *args, **kwargs):
        # override get_loss to use the sample_weight from X
        loss_unreduced = super().get_loss(y_pred, y_true, X, *args, **kwargs)
        sample_weight = X['sample_weight']
        loss_reduced = (sample_weight * loss_unreduced).mean()

torch.manual_seed(0)

net = NeuralNetClassifier(
    ClassifierModule,
    max_epochs=20,
    lr=0.1,
    device=device
)

abc = AdaBoostClassifier(NeuralNetClassifier(
    ClassifierModule,
    max_epochs=20,
    lr=0.1,
    device=device,
),n_estimators=7,learning_rate=.9)

abc.fit(X_train, y_train)

@QuantumChamploo
Author

This is essentially what I would like to get to work. I know that this is not the right way to implement sample_weights, as I could not make a simple example work, but I thought it would show what I am trying to get done. Making ADA work with skorch would be amazing, and it seems like it could be done.

@BenjaminBossan
Collaborator

BenjaminBossan commented Mar 22, 2020

I got something to work. Whether it's really correct, I'm not sure, since that would require digging into the sklearn implementation of AdaBoost. But at least for me, the model trains and gives reasonable results. Below are just the changed lines of code from your example; the rest is the same:

from skorch.dataset import get_len
from skorch.utils import to_device
import warnings

class ClassifierModule(nn.Module):
    def __init__(
        self,
        input_dim=mnist_dim,
        hidden_dim=hidden_dim,
        output_dim=output_dim,
        dropout=0.5,
    ):
        super(ClassifierModule, self).__init__()

        self.dropout = nn.Dropout(dropout)
        self.hidden = nn.Linear(input_dim, hidden_dim)
        self.output = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, X, **kwargs):
        X = F.relu(self.hidden(X))
        X = self.dropout(X)
        X = F.softmax(self.output(X), dim=-1)
        return X

class AdaboostMixin:
    """Mixin class to make the neural net compatible with sklearn's AdaBoost"""
    def fit(self, X, y=None, sample_weight=None, **fit_params):
        if getattr(self, 'criterion__reduction', True) != 'none':
            raise ValueError("Please pass argument criterion__reduction='none'")

        if sample_weight is None:
            raise ValueError("Must pass sample weights")

        # get_len also handles dict inputs, where len(X) would count keys, not samples
        m, n = get_len(X), len(sample_weight)
        if m != n:
            raise ValueError("X and sample_weight have different lengths: {} vs {}".format(m, n))

        if isinstance(X, dict) and ('sample_weight' in X):
            raise ValueError("X already contains a key called sample_weight")

        if not isinstance(X, dict):
            warnings.warn("Converting input data to dict")
            X = {'X': X}
        else:
            X = X.copy()

        X['sample_weight'] = sample_weight
        return super().fit(X, y, **fit_params)

    def get_loss(self, y_pred, y_true, X, *args, **kwargs):
        # override get_loss to use the sample_weight from X
        loss_unreduced = super().get_loss(y_pred, y_true, X, *args, **kwargs)
        sample_weight = X['sample_weight']
        sample_weight = to_device(sample_weight, self.device)
        loss_reduced = (sample_weight * loss_unreduced).mean()
        return loss_reduced

class MyNeuralNetClassifier(AdaboostMixin, NeuralNetClassifier):
    pass

abc = AdaBoostClassifier(
    MyNeuralNetClassifier(
        ClassifierModule,
        max_epochs=5,
        lr=0.1,  # <- probably a much bigger learning rate is required
        device=device,
        criterion__reduction='none',
        # train_split=False,  # <- consider turning off the train split
    ),
    n_estimators=5,
    learning_rate=.9
)

With those hyper-parameters, each of the 5 nets gets a validation accuracy of around 0.5 and the ensemble achieves 0.79.

The main change that was necessary is to add sample_weight to the signature of fit, since its absence is why sklearn raises the error you saw. Next we inject the sample weights into X. It's a bit hacky but serves the purpose. As you can see, there is not really a lot of new code; most of the code is just checking that everything is as expected.
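As a rough illustration of why the signature matters: a meta-estimator can inspect the signature of fit to decide whether a base estimator accepts sample weights. The sketch below uses only the standard library. The classes PlainNet and WeightedNet and the helper accepts_sample_weight are made up for this example and only approximate the kind of check sklearn performs, so treat it as an assumption rather than sklearn's actual code.

```python
import inspect


def accepts_sample_weight(estimator):
    """Return True if estimator.fit has a parameter named sample_weight."""
    return "sample_weight" in inspect.signature(estimator.fit).parameters


class PlainNet:
    # hypothetical stand-in for a net whose fit only takes X, y and **fit_params
    def fit(self, X, y=None, **fit_params):
        return self


class WeightedNet(PlainNet):
    # hypothetical stand-in for the mixin: sample_weight is an explicit parameter
    def fit(self, X, y=None, sample_weight=None, **fit_params):
        return super().fit(X, y, **fit_params)


print(accepts_sample_weight(PlainNet()))
print(accepts_sample_weight(WeightedNet()))
```

Note that accepting **fit_params alone is not enough for a check like this, because the name sample_weight never appears in the signature.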

If this works for you as well (or you still have trouble), please report back. Also let us know if you gain further insight into whether this works correctly. If it works as expected, we could consider adding AdaboostMixin to our helpers.

@QuantumChamploo
Author

Wow, thank you so, so much. This is a big breakthrough for me and is deeply appreciated; it made my week, really. I will run this by my PI and see what tests we need to do to make sure it is working the way we think it is, but it seems great. I'll get back to you soon about it. I got slightly different base accuracies for the individual nets. Were they all at .5 acc, or just around there?

Thanks again

@BenjaminBossan
Collaborator

I really hope this works out for you, good luck.

There was some variance between nets: 0.5102, 0.5064, 0.6207, 0.5873, 0.5390 (valid accuracy).

For serious training, I would turn off skorch-internal validation though (train_split=False) and probably also logs (verbose=0), I just left it on out of curiosity.

@thomasjpfan
Member

> The main change that was necessary is to add sample_weight to the signature of fit, since that is why sklearn raises the error you saw.

Yea this is a thing. The mixin does more than "enable adaboost"; it makes it easy to enable sample_weights.

Quick thought: Isn't loss_unreduced a scalar here? If so, I think we wish to multiply the sample_weights with the "pointwise loss" and not an aggregation for the batch.

    def get_loss(self, y_pred, y_true, X, *args, **kwargs):
        # override get_loss to use the sample_weight from X
        loss_unreduced = super().get_loss(y_pred, y_true, X, *args, **kwargs)

(I didn't know there were users using adaboost with neural networks. Sounds fun tho :D)

@BenjaminBossan
Collaborator

> Yea this is a thing. The mixin does more than "enable adaboost"; it makes it easy to enable sample_weights.

Yes, you're right, AdaboostMixin would be a misnomer.

> Quick thought: Isn't loss_unreduced a scalar here? If so, I think we wish to multiply the sample_weights with the "pointwise loss" and not an aggregation for the batch.

Sorry, I didn't get the point.

@thomasjpfan
Member

> Sorry, I didn't get the point.

As in loss_unreduced will be a single number for the batch. The self.criterion_(y_pred, y_true) in get_loss will already do an aggregation that weights the samples uniformly. For example:

criterion = nn.NLLLoss()
y_pred = torch.Tensor(
    [[0.3, 0.1, 0.6],
     [0.4, 0.5, 0.1],
     [0.2, 0.2, 0.6]])
y_true = torch.tensor([1, 0, 2])
criterion(y_pred, y_true)
# tensor(-0.3667)

Let's say the sample_weights were [1, 2, 2]; they would not be able to correctly adjust the loss if it is a scalar.

@BenjaminBossan
Collaborator

> As in loss_unreduced will be a single number for the batch

It shouldn't be, because we have criterion__reduction='none'. The mixin class also checks for that.
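To make the difference concrete, here is a small sketch in plain Python that reuses the numbers from the NLLLoss example above together with the weights [1, 2, 2]. It only mimics what the criterion does with reduction='none' versus the default mean reduction; no torch call is involved, and the variable names are chosen for this illustration.

```python
# Same numbers as the NLLLoss example: 3 samples, 3 classes
y_pred = [[0.3, 0.1, 0.6],
          [0.4, 0.5, 0.1],
          [0.2, 0.2, 0.6]]
y_true = [1, 0, 2]
sample_weight = [1.0, 2.0, 2.0]

# reduction='none': one loss per sample (NLL just negates the target entry)
loss_unreduced = [-row[target] for row, target in zip(y_pred, y_true)]

# what the mixin computes: weight each sample's loss, then average
weighted = sum(w * loss for w, loss in zip(sample_weight, loss_unreduced)) / len(y_true)

# mean reduction: the per-sample information is gone before weights are applied,
# so the weights can only rescale the batch loss as a whole
loss_scalar = sum(loss_unreduced) / len(y_true)
weighted_too_late = sum(w * loss_scalar for w in sample_weight) / len(y_true)

print(loss_unreduced)
print(weighted, weighted_too_late)
```

The two results differ, which is why the mixin insists on an unreduced loss before multiplying by the weights.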

@rasbt

rasbt commented Apr 7, 2020

I agree with @BenjaminBossan -- I think the implementation should be fine due to the lines:

        loss_reduced = (sample_weight * loss_unreduced).mean()
        return loss_reduced

where loss_unreduced would be a list of the losses for each example in the current minibatch (not a scalar)

Regarding

> Quick thought: Isn't loss_unreduced a scalar here? If so, I think we wish to multiply the sample_weights with the "pointwise loss" and not an aggregation for the batch.

there is a valid concern though that I have. So, if I see this correctly -- I am not that familiar with skorch -- it automatically divides the dataset into minibatches for stochastic gradient descent. I am not sure if that's ideal for AdaBoost, because the sample_weight vector would then also be of size "minibatch size." I.e., a weight at a certain position won't correspond to a particular training example during training, since the minibatches change.

To make it work, I think the code needs to be modified such that sample_weight has the same size as the dataset (number of training examples), and one would have to keep track of which weight corresponds to which training example. This would probably be super complicated to implement, so I think that AdaBoost with this setting would really only work with "batch gradient descent" rather than "minibatch/stochastic gradient descent"?

@BenjaminBossan
Collaborator

> it automatically divides the dataset into minibatches for stochastic gradient descent

Correct, though you could set the batch size to be the size of the dataset (batch_size=-1) to effectively turn off batching. But training neural nets with very big batch sizes is delicate.

> the sample_weight vector would then also be of size "minibatch size." I.e., a weight at a certain position won't correspond to a particular training example during training, since the minibatches change.

This point I don't get. The sample weights will be batched together with the training data (in fact, it's treated the same as training data for batching purposes), so if a mini batch starts with a sample from position 1234, the first sample weight will also be the one from position 1234. Or do you mean something else?
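A toy sketch of that alignment, in plain Python: every key of the input dict is sliced with the same shuffled indices, so each weight travels with its sample. The helper iter_batches is hypothetical and only imitates the idea; it is not skorch's Dataset or DataLoader code.

```python
import random

# toy data: sample i has feature value i and weight i / 10, so pairs are easy to check
n_samples, batch_size = 10, 4
X = {
    "X": [float(i) for i in range(n_samples)],
    "sample_weight": [i / 10 for i in range(n_samples)],
}


def iter_batches(data, batch_size, seed=0):
    """Yield dict batches, slicing every key with the same shuffled indices."""
    n = len(next(iter(data.values())))
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    for start in range(0, n, batch_size):
        batch_idx = indices[start:start + batch_size]
        yield {key: [values[i] for i in batch_idx] for key, values in data.items()}


for batch in iter_batches(X, batch_size):
    # each printed weight is still one tenth of the feature value next to it
    print(batch["X"], batch["sample_weight"])
```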

Btw., I edited my snippet above. I removed the line sample_weight = sample_weight / sample_weight.sum() # normalize to 1 because sklearn already takes care of normalizing the sample weights to 1. However, that means that for each mini batch, the sum of the sample weights will be << 1. This in turn leads to losses with very small magnitude. Therefore, we need a very big learning rate (for this example 100) to compensate. It's not ideal but normalizing per mini batch would be incorrect.
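A back-of-the-envelope sketch of that scale problem, under stated assumptions: uniform weights that sum to 1 over the training set, a batch size of 128, and a per-sample loss of about 2.3 (roughly -log(1/10) for an untrained 10-class net). None of these numbers are measured from the run above.

```python
n_train = 52500        # 75% of the 70000 MNIST samples, as in the split above
batch_size = 128       # assumed batch size
per_sample_loss = 2.3  # assumed loss of an untrained 10-class net

# uniform weights that sum to 1 over the whole training set
weight = 1.0 / n_train

# with equal weights, (weight * loss).mean() is just weight * mean(loss)
unweighted_batch_loss = per_sample_loss
weighted_batch_loss = weight * per_sample_loss

print(batch_size * weight)                          # weight mass in one batch, far below 1
print(weighted_batch_loss)                          # loss shrinks by a factor of n_train
print(unweighted_batch_loss / weighted_batch_loss)  # the shrink factor itself
```

With plain SGD the gradient shrinks by the same factor as the loss, which is why a much larger learning rate is needed to compensate.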

@rasbt

rasbt commented Apr 7, 2020

> This point I don't get. The sample weights will be batched together with the training data (in fact, it's treated the same as training data for batching purposes), so if a mini batch starts with a sample from position 1234, the first sample weight will also be the one from position 1234. Or do you mean something else?

That's good, this would address my concern then :).

@BenjaminBossan
Collaborator

Great. Then let's wait for @QuantumChamploo to validate the results.

@BenjaminBossan
Collaborator

@QuantumChamploo did you have the time to try this yet?

@QuantumChamploo
Author

QuantumChamploo commented May 23, 2020

@BenjaminBossan yes, I have been testing this implementation. Unfortunately I ran into a very strange error when trying to use it for analysis, where I was more rigorously testing the ensemble's effectiveness. While trying to create a concise reproduction of the error, I figured out a workaround, but I would still like to share the error.

What I am trying to do is sweep through the learning rate at 3 different values (.1, .01, .001), 3 times each. The only difference between the two following code snippets is that I flipped the order of the two loops while recording these accuracy values.

"""

xs = []
ys = []


for j in range(3):

    x = []
    y = []



    for i in range(3):
        hld = .1**(i+1)
        print(hld)


        abc5 = AdaBoostClassifier(
            MyNeuralNetClassifier(
                ClassifierModule,
                max_epochs=5,
                lr=hld,
                device=device,
                optimizer=torch.optim.Adam,
                criterion__reduction='none',
            # train_split=False,  # <- consider turning off the train split
            ),
            n_estimators=5,
            learning_rate=.9
        )
        abc5.fit(X_train.astype(np.float32), y_train.astype(np.int64))
        y_pred = abc5.predict(X_test.astype(np.float32))
        x.append(abc5.estimators_[0].score(X_test.astype(np.float32),y_test))
        y.append(metrics.accuracy_score(y_test, y_pred))
    
    xs.append(x)
    ys.append(y)

"""

Screen Shot 2020-05-23 at 5 02 11 PM

"""

x_dists = []
y_dist = []

for j in range(3):
    x = []
    y = []
    for i in range(3):
    
        abc = AdaBoostClassifier(
            MyNeuralNetClassifier(
                ClassifierModule,
                max_epochs=5,
                lr=.1**(j+1),
                device=device,
                optimizer=torch.optim.Adam,
                criterion__reduction='none',
            # train_split=False,  # <- consider turning off the train split
            ),
            n_estimators=5,
            learning_rate=.9
        )
        abc.fit(X_train.astype(np.float32), y_train.astype(np.int64))
        y_pred = abc.predict(X_test.astype(np.float32))
        x.append(abc.estimators_[0].score(X_test.astype(np.float32),y_test))
        y.append(metrics.accuracy_score(y_test, y_pred))
    
    x_dists.append(x)
    y_dist.append(y)

"""

Screen Shot 2020-05-23 at 5 08 52 PM

So in the first screenshot each array represents a cycle through the parameters, while in the second each array holds just one value of the hyperparameter. In the second snippet, the trend is not what we expected or have seen when doing individual runs. The values seem much more random, while the first snippet has the clear trend that we expect.

This was very strange, and was a big roadblock. It seems I found a workaround, but any clue why this would happen? My guess is some sort of issue with sharing weights from instantiation to instantiation.

@BenjaminBossan
Collaborator

Thx for reporting back @QuantumChamploo

So if I understand correctly, the values for x_dist and y_dist should be similar for the two experiments, just transposed. Indeed, the problem could be related to the model not being properly reset. Which at first seems strange since you initialize a completely new net and module in the inner loop. So if that's the issue, it must relate to something (presumably the use of clone) that happens inside of AdaBoostClassifier.

We had a recent fix (#617) that looks like it could be related. Could you try the current master branch from skorch and see if that fixes your issue?

@QuantumChamploo
Author

@BenjaminBossan sorry if this is a silly question, but to get the current master branch of skorch, I would not reinstall with pip, but rather pull the master branch into my local Python environment?

@BenjaminBossan
Collaborator

I don't know exactly how you set up your local environment. Typically, it should work if you pip uninstall skorch first, then follow the instructions here.

However, if you don't want to clone the github repo, pip also has an option to install from github, explained in this SO. So I believe pip install git+https://github.com/skorch-dev/skorch should be sufficient.

@QuantumChamploo
Author

So I re-did the code with the current master branch and got the same erroneous distribution
image

@rasbt

rasbt commented Jun 10, 2020

Just to make sure that the correct version was installed, could you execute

skorch.__version__  

in the cell above your results, just to make sure that it's indeed the expected version. Jupyter Notebook/Lab is sometimes a bit weird with handling / respecting different virtual environments as I found out in recent debugging projects ...
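A slightly more informative variant of that check, as a sketch: besides the version string, it reports where the module would be imported from, which helps spot a notebook kernel that points at a different environment. The helper describe_install is a name made up for this example; it only uses the standard library's importlib.

```python
from importlib import metadata, util


def describe_install(package):
    """Return the installed version and import location of a package (None if missing)."""
    try:
        version = metadata.version(package)
    except metadata.PackageNotFoundError:
        version = None
    spec = util.find_spec(package)
    location = spec.origin if spec is not None else None
    return {"version": version, "location": location}


# run this in the same kernel that produced the results
print(describe_install("skorch"))
```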

@BenjaminBossan
Collaborator

Though the version doesn't tell you if the code is from the current master or not.

@BenjaminBossan
Collaborator

@QuantumChamploo if you could make it possible to replicate your problem, I would gladly take a look at it.

@rasbt

rasbt commented Jun 11, 2020

> Though the version doesn't tell you if the code is from the current master or not.

Oh you are right. I just see that the latest version on PyPI is the same version as in the master branch. Usually, recommended practice reg. Python's open source sci-stack is to bump the version number in the master for the dev version. I.e., 0.8.1dev or 0.9.0a0. Both PyTorch and Scikit-learn do this. For ease of debugging etc. I wonder what you think about adopting it as well here?

@BenjaminBossan
Collaborator

You are correct, it would be nice if we bumped the version after release. We rarely do patch level releases, but probably 0.8.1dev is still the safe bet. @ottonemo any opinion?

@rasbt

rasbt commented Jun 11, 2020

> We rarely do patch level releases, but probably 0.8.1dev is still the safe bet.

Sure. It's often also only used as a placeholder. I.e., you use 0.8.1dev during dev and then later swap it with 0.8.1, 0.9.0, or 1.0.0 -- whatever is most appropriate

@QuantumChamploo
Author

QuantumChamploo commented Jun 20, 2020

below are my two pieces of code

the fixed
'''
# gitFix.py

from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
import numpy as np
import matplotlib.pyplot as plt
from sklearn import metrics

import torch
from torch import nn
import torch.nn.functional as F

from skorch import NeuralNetClassifier

from sklearn.ensemble import AdaBoostClassifier
from skorch.dataset import get_len
import warnings
import random

mnist = fetch_openml('mnist_784', cache=False)

x = mnist.data.astype('float32')
y = mnist.target.astype('int64')

x /= 255.0


X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.25, random_state=42)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

mnist_dim = x.shape[1]
#hidden_dim = int(mnist_dim/8)
hidden_dim =  8
output_dim = len(np.unique(mnist.target))

X = {'data': X_train}
X['sample_weight'] = np.ones(len(X_train), dtype=float)



class ClassifierModule(nn.Module):
    def __init__(
        self,
        input_dim=mnist_dim,
        hidden_dim=hidden_dim,
        output_dim=output_dim,
        dropout=0.5,
    ):
        super(ClassifierModule, self).__init__()

        self.dropout = nn.Dropout(dropout)
        self.hidden = nn.Linear(input_dim, hidden_dim)
        self.output = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, X, **kwargs):
        X = F.relu(self.hidden(X))
        X = self.dropout(X)
        X = F.softmax(self.output(X), dim=-1)
        return X

class AdaboostMixin:
    """Mixin class to make the neural net compatible with sklearn's AdaBoost"""
    def fit(self, X, y=None, sample_weight=None, **fit_params):
        if getattr(self, 'criterion__reduction', True) != 'none':
            raise ValueError("Please pass argument criterion__reduction='none'")

        if sample_weight is None:
            raise ValueError("Must pass sample weights")

        m, n = len(X), len(sample_weight)
        if m != n:
            raise ValueError("X and sample_weight have different lengths: {} vs {}".format(m, n))

        if isinstance(X, dict) and ('sample_weight' in X):
            raise ValueError("X already contains a key called sample_weight")

        if not isinstance(X, dict):
            warnings.warn("Converting input data to dict")
            X = {'X': X}
        else:
            X = X.copy()

        X['sample_weight'] = sample_weight
        return super().fit(X, y, **fit_params)

    def get_loss(self, y_pred, y_true, X, *args, **kwargs):
        # override get_loss to use the sample_weight from X
        loss_unreduced = super().get_loss(y_pred, y_true, X, *args, **kwargs)
        sample_weight = X['sample_weight']
        sample_weight = sample_weight / sample_weight.sum()  # normalize to 1
        loss_reduced = (sample_weight * loss_unreduced).mean()
        return loss_reduced

class MyNeuralNetClassifier(AdaboostMixin, NeuralNetClassifier):
    pass

xs = []
ys = []


for j in range(3):

    x = []
    y = []



    for i in range(3):
        hld = .1**(i+1)
        print(hld)


        abc5 = AdaBoostClassifier(
            MyNeuralNetClassifier(
                ClassifierModule,
                max_epochs=5,
                lr=hld,
                device=device,
                optimizer=torch.optim.Adam,
                criterion__reduction='none',
                # train_split=False,  # <- consider turning off the train split
            ),
            n_estimators=5,
            learning_rate=.9
        )
        abc5.fit(X_train.astype(np.float32), y_train.astype(np.int64))
        y_pred = abc5.predict(X_test.astype(np.float32))
        x.append(abc5.estimators_[0].score(X_test.astype(np.float32),y_test))
        y.append(metrics.accuracy_score(y_test, y_pred))
        
    xs.append(x)
    ys.append(y)


print(xs)
print(ys)

'''
And the broken version
'''
# gitBroken.py

from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
import numpy as np
import matplotlib.pyplot as plt
from sklearn import metrics

import torch
from torch import nn
import torch.nn.functional as F

from skorch import NeuralNetClassifier

from sklearn.ensemble import AdaBoostClassifier
from skorch.dataset import get_len
import warnings

from sklearn.datasets import fetch_openml
import datetime
import pandas as pd
import re

mnist = fetch_openml('mnist_784', cache=False)

x = mnist.data.astype('float32')
y = mnist.target.astype('int64')

x /= 255.0


X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.25, random_state=42)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

X_train = np.array(X_train, dtype=float)
y_train = np.array(y_train, dtype=float)
X_train /= 255.0

mnist_dim = len(X_train[0])
output_dim = 10

X = {'data': X_train}
X['sample_weight'] = np.ones(len(X_train), dtype=float)

data_array = []
x_dists = []
y_dist = []


class ClassifierModule(nn.Module):
    def __init__(
        self,
        input_dim=mnist_dim,
        hidden_dim=8,
        output_dim=output_dim,
        dropout=0.5,
    ):
        super(ClassifierModule, self).__init__()

        self.dropout = nn.Dropout(dropout)
        self.hidden = nn.Linear(input_dim, hidden_dim)
        self.output = nn.Linear(hidden_dim, output_dim)

    def forward(self, X, **kwargs):
        X = F.relu(self.hidden(X))
        X = self.dropout(X)
        X = F.softmax(self.output(X), dim=-1)
        return X

class AdaboostMixin:
    """Mixin class to make the neural net compatible with sklearn's AdaBoost"""
    def fit(self, X, y=None, sample_weight=None, **fit_params):
        if getattr(self, 'criterion__reduction', True) != 'none':
            raise ValueError("Please pass argument criterion__reduction='none'")

        if sample_weight is None:
            raise ValueError("Must pass sample weights")

        m, n = len(X), len(sample_weight)
        if m != n:
            raise ValueError("X and sample_weight have different lengths: {} vs {}".format(m, n))

        if isinstance(X, dict) and ('sample_weight' in X):
            raise ValueError("X already contains a key called sample_weight")

        if not isinstance(X, dict):
            warnings.warn("Converting input data to dict")
            X = {'X': X}
        else:
            X = X.copy()

        X['sample_weight'] = sample_weight
        return super().fit(X, y, **fit_params)

    def get_loss(self, y_pred, y_true, X, *args, **kwargs):
        # override get_loss to use the sample_weight from X
        loss_unreduced = super().get_loss(y_pred, y_true, X, *args, **kwargs)
        sample_weight = X['sample_weight']
        sample_weight = sample_weight / sample_weight.sum()  # normalize to 1
        loss_reduced = (sample_weight * loss_unreduced).mean()
        return loss_reduced

class MyNeuralNetClassifier(AdaboostMixin, NeuralNetClassifier):
    pass


for j in range(3):
    x = []
    y = []
    for i in range(3):
        
        abc = AdaBoostClassifier(
            MyNeuralNetClassifier(
                ClassifierModule,
                max_epochs=5,
                lr=.1**(j+1),
                device=device,
                optimizer=torch.optim.Adam,
                criterion__reduction='none',
                # train_split=False,  # <- consider turning off the train split
            ),
            n_estimators=5,
            learning_rate=.9
        )
        abc.fit(X_train.astype(np.float32), y_train.astype(np.int64))
        y_pred = abc.predict(X_test.astype(np.float32))
        x.append(abc.estimators_[0].score(X_test.astype(np.float32),y_test))
        y.append(metrics.accuracy_score(y_test, y_pred))
        
    x_dists.append(x)
    y_dist.append(y)
    print('done with another')

print(x_dists)
print(y_dist)

'''

@BenjaminBossan
Collaborator

@QuantumChamploo Thanks for the code, I tested it on the current master branch and this is what I got:

>>> # first example
>>> x_dists
[[0.19582857142857144, 0.7357714285714285, 0.868],
 [0.1944, 0.7064, 0.8588571428571429],
 [0.11297142857142857, 0.8406857142857143, 0.8203428571428572]]

>>> # second example
>>> x_dists
[[0.18571428571428572, 0.23194285714285715, 0.18457142857142858],
 [0.8293142857142857, 0.8109142857142857, 0.8523428571428572],
 [0.854, 0.8609714285714286, 0.8674285714285714]]

>>> # first example
>>> y_dist
[[0.3450857142857143, 0.8408, 0.8953142857142857],
 [0.3110857142857143, 0.8125142857142857, 0.8893714285714286],
 [0.24577142857142859, 0.8480571428571428, 0.8936571428571428]]
>>> # second example
>>> y_dist
[[0.42028571428571426, 0.4465142857142857, 0.26794285714285715],
 [0.8168, 0.8121714285714285, 0.8471428571428572],
 [0.8905714285714286, 0.8922285714285715, 0.8905714285714286]]

So I believe this looks correct, right? The second example is roughly the same as the first, only transposed. There is some randomness in the scores, but that's probably within the expected range.
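One quick way to compare the two runs is to transpose the first result so that both have one row per learning rate, then look at the mean per row. The sketch below copies the x_dists values from the output above; the helper mean and the other variable names are only for this illustration.

```python
# accuracies of estimators_[0], copied from the x_dists output above
first = [[0.19582857142857144, 0.7357714285714285, 0.868],
         [0.1944, 0.7064, 0.8588571428571429],
         [0.11297142857142857, 0.8406857142857143, 0.8203428571428572]]
second = [[0.18571428571428572, 0.23194285714285715, 0.18457142857142858],
          [0.8293142857142857, 0.8109142857142857, 0.8523428571428572],
          [0.854, 0.8609714285714286, 0.8674285714285714]]

learning_rates = [0.1, 0.01, 0.001]

# rows of `first` are repeats, so transpose to get one row per learning rate
first_by_lr = [list(col) for col in zip(*first)]


def mean(values):
    return sum(values) / len(values)


for lr, a, b in zip(learning_rates, first_by_lr, second):
    print(f"lr={lr}: first={mean(a):.3f} second={mean(b):.3f}")
```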

Assuming this is indeed correct, the bugfix mentioned above is probably the reason for why it works now and you need to make sure to really work on the current master (with the version bump, it should now be easier to verify, it should be 0.8.1dev).

@QuantumChamploo
Author

QuantumChamploo commented Jul 1, 2020

@BenjaminBossan yes this looks great! I'll bump to the current master and give this a try.
I also wanted to ask: I have been trying to use the GPU with the ADA implementation worked on above. I have tested a simple model from the skorch examples on my workstation, and the GPU is working. With the ADA implementation, I get the following error:

Screen Shot 2020-07-01 at 4 52 25 PM

It seems the AdaboostMixin class is getting tripped up on the gpu device. The following is code that should reproduce the issue:

from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
import numpy as np
import matplotlib.pyplot as plt
from sklearn import metrics

import torch
from torch import nn
import torch.nn.functional as F

from skorch import NeuralNetClassifier

from sklearn.ensemble import AdaBoostClassifier
from skorch.dataset import get_len
import warnings


mnist = fetch_openml('mnist_784', cache=False)

x = mnist.data.astype('float32')
y = mnist.target.astype('int64')

x /= 255.0


X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.25, random_state=42)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

print(device)

mnist_dim = x.shape[1]
#hidden_dim = int(mnist_dim/8)
hidden_dim =  8
output_dim = len(np.unique(mnist.target))

X = {'data': X_train}
X['sample_weight'] = np.ones(len(X_train), dtype=float)


class ClassifierModule(nn.Module):
    def __init__(
        self,
        input_dim=mnist_dim,
        hidden_dim=hidden_dim,
        output_dim=output_dim,
        dropout=0.5,
    ):
        super(ClassifierModule, self).__init__()

        self.dropout = nn.Dropout(dropout)
        self.hidden = nn.Linear(input_dim, hidden_dim)
        self.output = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, X, **kwargs):
        X = F.relu(self.hidden(X))
        X = self.dropout(X)
        X = F.softmax(self.output(X), dim=-1)
        return X

class AdaboostMixin:
    """Mixin class to make the neural net compatible with sklearn's AdaBoost"""
    def fit(self, X, y=None, sample_weight=None, **fit_params):
        if getattr(self, 'criterion__reduction', True) != 'none':
            raise ValueError("Please pass argument criterion__reduction='none'")

        if sample_weight is None:
            raise ValueError("Must pass sample weights")

        m, n = len(X), len(sample_weight)
        if m != n:
            raise ValueError("X and sample_weight have different lengths: {} vs {}".format(m, n))

        if isinstance(X, dict) and ('sample_weight' in X):
            raise ValueError("X already contains a key called sample_weight")

        if not isinstance(X, dict):
            warnings.warn("Converting input data to dict")
            X = {'X': X}
        else:
            X = X.copy()

        X['sample_weight'] = sample_weight
        return super().fit(X, y, **fit_params)

    def get_loss(self, y_pred, y_true, X, *args, **kwargs):
        # override get_loss to use the sample_weight from X
        loss_unreduced = super().get_loss(y_pred, y_true, X, *args, **kwargs)
        sample_weight = X['sample_weight']
        sample_weight = sample_weight / sample_weight.sum()  # normalize to 1
        loss_reduced = (sample_weight * loss_unreduced).mean()
        return loss_reduced

class MyNeuralNetClassifier(AdaboostMixin, NeuralNetClassifier):
    pass


abc = AdaBoostClassifier(
        MyNeuralNetClassifier(
            ClassifierModule,
            max_epochs=5,
            lr=.001,
            device=device,
            optimizer=torch.optim.Adam,
            iterator_train__shuffle=True,
            iterator_train__num_workers=4,
            iterator_valid__shuffle=True,
            iterator_valid__num_workers=4,
            criterion__reduction='none',
            # train_split=False,  # <- consider turning off the train split
        ),
        n_estimators=5,
        learning_rate=.9
    )

abc.fit(X_train,y_train)
y_pred = abc.predict(X_test)

print("the overall acc")
print(metrics.accuracy_score(y_test, y_pred))

print("individual acc")
print(abc.estimators_[0].score(X_test,y_test))
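
As a side note on the weighting step in get_loss above, here is a minimal sketch in plain Python (no torch) of the arithmetic it performs on one batch. The function names and the toy numbers are made up for illustration. It suggests that normalizing the weights to sum to 1 and then calling .mean() divides by the batch size one extra time compared with a plain weighted average; whether that extra scaling matters in practice is not something this sketch tests.

```python
def weighted_loss_mean(losses, weights):
    """Mirror of get_loss above: normalize weights to sum to 1, then take the mean."""
    total = sum(weights)
    normalized = [w / total for w in weights]
    return sum(w * l for w, l in zip(normalized, losses)) / len(losses)


def weighted_loss_sum(losses, weights):
    """Plain weighted average: normalize weights to sum to 1, then sum."""
    total = sum(weights)
    return sum((w / total) * l for w, l in zip(weights, losses))


# Hypothetical per-sample losses for a batch of four samples.
losses = [0.5, 1.0, 2.0, 4.0]
uniform = [1.0, 1.0, 1.0, 1.0]   # what AdaBoost starts with
boosted = [1.0, 1.0, 1.0, 5.0]   # last sample up-weighted after a mistake

uniform_avg = weighted_loss_sum(losses, uniform)
boosted_avg = weighted_loss_sum(losses, boosted)
uniform_mean_variant = weighted_loss_mean(losses, uniform)
boosted_mean_variant = weighted_loss_mean(losses, boosted)

print(uniform_avg, boosted_avg)
print(uniform_mean_variant, boosted_mean_variant)
```

In both variants the up-weighted sample pulls the loss towards its own value, which is the behaviour AdaBoost relies on; the two only differ by a constant factor of the batch size.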

@BenjaminBossan
Collaborator

> I'll bump to the current master and give this a try.
>
> fingers crossed

> It seems the AdaboostMixin class is getting tripped up on the gpu device

Without running your sample code, my guess is that the sample_weight is still on CPU while the loss is on GPU. Maybe you could check their respective devices. If my suspicion is true, this should fix it:

from skorch.utils import to_device

...

        # inside get_loss
        sample_weight = X['sample_weight']
        sample_weight = to_device(sample_weight, self.device)
        ...

@QuantumChamploo
Author

QuantumChamploo commented Jul 2, 2020

That seemed to work. I also found that
sample_weight = sample_weight.to(device)

fixed the issue as well. I wanted to note that I passed the following arguments to the neural net object:

[Screenshot: Screen Shot 2020-07-02 at 4 58 45 PM]

I initially had the shuffle flags on, but this ruined the net, i.e. the test accuracy dropped to ~10 percent, which is just random guessing. The neural net was being trained correctly (you could see that in the verbose output), but the test accuracy was degraded. My guess is that something got shuffled improperly. It's not a big deal, because I got this working, but I thought I would mention it.

@BenjaminBossan
Collaborator

> I initially had the shuffle flags on, but this ruined the net, i.e. made it so that it went to ~10 percent, which is just random guessing.

Yes, when you shuffle the validation data, your predictions are shuffled too, so they are no longer aligned with your targets, which, as you observed, results in chance-level scores.
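
To illustrate the misalignment, here is a small sketch in plain Python with made-up data. It assumes a hypothetical model that is perfect when predictions come back in the original order, then shuffles the predictions the way a shuffling iterator would and scores them against the unshuffled targets.

```python
import random

random.seed(0)

n_samples, n_classes = 10_000, 10
y_true = [random.randrange(n_classes) for _ in range(n_samples)]

# Pretend the model is perfect: in the original order, predictions equal targets.
y_pred_ordered = list(y_true)

# A shuffling iterator yields the same predictions, but in a different order.
y_pred_shuffled = list(y_pred_ordered)
random.shuffle(y_pred_shuffled)


def accuracy(targets, predictions):
    """Fraction of positions where prediction and target agree."""
    return sum(t == p for t, p in zip(targets, predictions)) / len(targets)


acc_ordered = accuracy(y_true, y_pred_ordered)
acc_shuffled = accuracy(y_true, y_pred_shuffled)
print(acc_ordered, acc_shuffled)
```

With ten roughly balanced classes the shuffled accuracy lands near 1/10, even though every individual prediction is still correct for some sample.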

@BenjaminBossan
Collaborator

I think this issue has been solved. We could consider adding AdaBoostMixin to the skorch code but probably its usefulness is too narrow to make that worthwhile.

@francescamanni1989

Hi everyone, this was very useful for me. I tried to apply the same concept to GradientBoosting, but I get this error:

base_estimator = NeuralNetClassifier(
    ClassifierModule,
    max_epochs=100,
    # lr=hld,
    device=device,
    optimizer=torch.optim.Adam,
    criterion__reduction='none',
    # train_split=False,  # <- consider turning off the train split
)

abc = GradientBoostingClassifier(base_estimator, n_estimators=7,learning_rate=.9)

abc.fit(X_train, y_train)

/opt/conda/lib/python3.9/site-packages/sklearn/ensemble/_gb.py in _check_params(self)
    237         if (self.loss not in self._SUPPORTED_LOSS
    238                 or self.loss not in _gb_losses.LOSS_FUNCTIONS):
--> 239             raise ValueError("Loss '{0:s}' not supported. ".format(self.loss))
    240
    241         if self.loss == 'deviance':

TypeError: unsupported format string passed to NeuralNetClassifier.__format__

Does anyone have any idea how to make NeuralNetClassifier compatible? I don't understand why the loss parameter of GB is not compatible.
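
One possible reading of the traceback, sketched in plain Python without sklearn or skorch: the check formats self.loss with the string-only spec '{0:s}', and that fails with a TypeError when self.loss is an arbitrary object rather than a string. FakeNet and check_loss below are hypothetical stand-ins written for this illustration, not sklearn code.

```python
class FakeNet:
    """Hypothetical stand-in for a NeuralNetClassifier instance."""


def check_loss(loss="deviance", supported=("deviance", "exponential")):
    # Same shape as the check in the traceback: an unknown loss is reported
    # by formatting it with the string-only '{0:s}' spec.
    if loss not in supported:
        raise ValueError("Loss '{0:s}' not supported. ".format(loss))
    return loss


def try_loss(loss):
    """Return the name and message of the exception raised, if any."""
    try:
        check_loss(loss)
    except (TypeError, ValueError) as exc:
        return type(exc).__name__, str(exc)
    return None, ""


# An unsupported string reaches the intended ValueError ...
string_result = try_loss("hinge")
# ... but an arbitrary object fails earlier, inside str.format itself.
object_result = try_loss(FakeNet())

print(string_result)
print(object_result)
```

If that reading is right, the net ended up in the loss parameter because it was passed as the first positional argument. As far as I know, sklearn's GradientBoostingClassifier has no base-estimator parameter at all (its boosting stages are regression trees), so the AdaBoost approach from this thread would not transfer to it directly.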
