Custom loss function on subset of parameters fails #1371

Closed
Antomek opened this issue Oct 22, 2020 · 2 comments
Antomek commented Oct 22, 2020

Hello all,

Sometimes I want to use custom loss functions on a subset of parameters.
My way of doing this was based on #939.
The regularization I was using worked fine with Tracker, but now with Zygote the loss suddenly starts increasing during training (and training also appears to be slower).

For example, consider:

using Distributions
using Random
using Statistics
using Flux
using Flux: onehotbatch, onecold, crossentropy, throttle

imgs = Flux.Data.MNIST.images()
labels = Flux.Data.MNIST.labels();
## Boring Preprocessing
X = hcat(float.(reshape.(imgs, :))...)
Y = onehotbatch(labels, 0:9);

X_training = X[:, 1:10000]
Y_training = Y[:, 1:10000]

# Function that grabs the biases vectors from the Flux parameters
function bias_params(m::Chain, ps=Flux.Params())
    map((l)->bias_params(l, ps), m.layers)
    ps
end
bias_params(m::Dense, ps=Flux.Params()) = push!(ps, m.b)
bias_params(m, ps=Flux.Params()) = ps

accuracyLoss(m, x, y) = crossentropy(m(x), y)
function biasTestLoss(m)
    Float32(sum([sum(x.^2) for x in bias_params(m)]))
end

loss(m, x, y) = accuracyLoss(m, x, y) + biasTestLoss(m)

# Train this model with the custom loss function
model = Chain(
    Dense(28^2, 32, σ),
    Dropout(0.1),
    Dense(32, 10, σ),
    Dropout(0.1),
    Dense(10, 10, σ),
    Dropout(0.1),
    Dense(10, 10, σ),
    softmax)

dataset = (X_training, Y_training)
epochs = 100
updateEpochs = 10
batchSize = 50
accuracy(m, x, y) = mean(onecold(m(x)) .== onecold(y))

shuffled_indices = randperm(size(X_training)[2])
shuffled_X_training = X_training[:, shuffled_indices]
shuffled_Y_training = Y_training[:, shuffled_indices]
N = size(X_training)[2]
batched_training_set = [[(cat(Float32.(shuffled_X_training[:, i]), dims=2), shuffled_Y_training[:, i])] for i in Iterators.partition(1:N, batchSize)]

lossRecord = []

for epoch in 1:epochs
    for n in 1:length(batched_training_set)
        Flux.train!((x, y) -> loss(model, x, y), Flux.params(model), batched_training_set[n], ADAM())
    end

    if epoch % updateEpochs == 0
        println("At epoch ", epoch, " the loss is: ", loss(model, X_training, Y))
    end

    push!(lossRecord, (epoch, loss(model, X_training, Y_training)))
end
println("The training accuracy of the network is: ", accuracy(model, X_training, Y))

This strange behaviour mystifies me because, as noted above, the same approach worked and gave the expected results with the Tracker backend.

@CarloLucibello
Member

I could not run the script; it contains a few syntax errors.

@ToucheSir
Member

Unlike Tracker, Zygote is generally not great with a lot of mutation. Thankfully, this kind of pattern is now trivial with the new modules function: https://fluxml.ai/Flux.jl/stable/utilities/#Flux.modules. In light of that, I think this can be closed.
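
For concreteness, a minimal sketch of that pattern (not code from this thread), assuming Flux ≥ 0.12 where Flux.modules is available and Dense stores its bias in the bias field (older releases use b):

using Flux

# Mutation-free bias regulariser: walk every sub-module of the model with
# Flux.modules and penalise the bias of each Dense layer. Nothing is push!-ed
# into a Params collection inside the loss, so Zygote can differentiate it.
bias_penalty(m) = sum(sum(abs2, l.bias) for l in Flux.modules(m) if l isa Dense)

model = Chain(Dense(28^2, 32, σ), Dense(32, 10), softmax)
bias_penalty(model)  # scalar penalty, usable as an additive term in the loss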
