Sometimes I want to use custom loss functions on a subset of parameters.
My way of doing this was based on #939.
The regularization I was using worked fine with Tracker, but now with Zygote the loss suddenly starts increasing partway through training (and training appears to be slower).
For example, consider:
```julia
using Distributions
using Random
using Statistics: mean
using Flux
using Flux: onehotbatch, onecold, crossentropy, throttle

imgs = Flux.Data.MNIST.images()
labels = Flux.Data.MNIST.labels();

## Boring preprocessing
X = hcat(float.(reshape.(imgs, :))...)
Y = onehotbatch(labels, 0:9);
X_training = X[:, 1:10000]
Y_training = Y[:, 1:10000]

# Function that grabs the bias vectors from the Flux parameters
function bias_params(m::Chain, ps=Flux.Params())
    map((l) -> bias_params(l, ps), m.layers)
    ps
end
bias_params(m::Dense, ps=Flux.Params()) = push!(ps, m.b)
bias_params(m, ps=Flux.Params()) = ps

accuracyLoss(m, x, y) = crossentropy(m(x), y)
function biasTestLoss(m)
    Float32(sum([sum(x.^2) for x in bias_params(m)]))
end
loss(m, x, y) = accuracyLoss(m, x, y) + biasTestLoss(m)
accuracy(m, x, y) = mean(onecold(m(x)) .== onecold(y))

# Train this model with the custom loss function
model = Chain(
    Dense(28^2, 32, σ),
    Dropout(0.1),
    Dense(32, 10, σ),
    Dropout(0.1),
    Dense(10, 10, σ),
    Dropout(0.1),
    Dense(10, 10, σ),
    softmax)

dataset = (X_training, Y_training)
epochs = 100
updateEpochs = 10
batchSize = 50

# Shuffle and batch the training set
shuffled_indices = randperm(size(X_training)[2])
shuffled_X_training = X_training[:, shuffled_indices]
shuffled_Y_training = Y_training[:, shuffled_indices]
N = size(X_training)[2]
batched_training_set = [[(cat(Float32.(shuffled_X_training[:, i]), dims=2), shuffled_Y_training[:, i])] for i in Iterators.partition(1:N, batchSize)]

lossRecord = []
opt = ADAM()  # create the optimiser once so its state persists across batches
for epoch in 1:epochs
    for n in 1:length(batched_training_set)
        Flux.train!((x, y) -> loss(model, x, y), Flux.params(model), batched_training_set[n], opt)
    end
    if epoch % updateEpochs == 0
        println("At epoch ", epoch, " the loss is: ", loss(model, X_training, Y_training))
    end
    push!(lossRecord, (epoch, loss(model, X_training, Y_training)))
end
println("The training accuracy of the network is: ", accuracy(model, X_training, Y_training))
```
This strange behaviour mystifies me because, as I said, this method used to work and gave the expected results with the Tracker backend.
Unlike Tracker, Zygote is generally not great with a lot of mutation. Thankfully, this kind of pattern is now trivial with the new modules function: https://fluxml.ai/Flux.jl/stable/utilities/#Flux.modules. In light of that, I think this can be closed.
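For anyone who finds this later, here is a minimal sketch of that pattern (my own illustration, not from the original post; it assumes Flux ≥ 0.12, where `Dense` stores its bias in a `bias` field and the loss functions live under `Flux.Losses`):

```julia
using Flux

# Sum of squared biases, collected without mutating a Params container:
# Flux.modules walks the model tree, and we keep only the Dense layers.
bias_penalty(m) = sum(sum(abs2, l.bias) for l in Flux.modules(m) if l isa Dense)

model = Chain(Dense(28^2, 32, σ), Dense(32, 10), softmax)
loss(m, x, y) = Flux.Losses.crossentropy(m(x), y) + bias_penalty(m)

# Zygote differentiates this cleanly, since nothing is push!-ed during the forward pass
x = rand(Float32, 28^2, 5)
y = Flux.onehotbatch(rand(0:9, 5), 0:9)
gs = gradient(() -> loss(model, x, y), Flux.params(model))
```

The key difference from the `bias_params` approach above is that no `Flux.Params()` container is mutated inside the loss, which is exactly the kind of mutation Zygote handles poorly.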