-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_segnet.m
63 lines (48 loc) · 1.83 KB
/
train_segnet.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
% ---------------------------------------------------------------------
% Load MNIST from the IDX files (relative paths).
% processImagesMNIST / processLabelsMNIST are helpers defined outside
% this script.
% ---------------------------------------------------------------------

% IDX file names.
trainImagesFile = 'train-images-idx3-ubyte';
testImagesFile  = 't10k-images-idx3-ubyte';
testLabelsFile  = 't10k-labels-idx1-ubyte';

% Training images. The image count is read from dimension 4, which is
% also the dimension the training loop slices mini-batches along.
XTrain         = processImagesMNIST(trainImagesFile);
numTrainImages = size(XTrain, 4);

% Test images plus labels; in this script the labels are only passed to
% the final visualisation helper.
XTest = processImagesMNIST(testImagesFile);
YTest = processLabelsMNIST(testLabelsFile);
% ---------------------------------------------------------------------
% Build the model from a SegNet layer graph and convert it to a
% dlnetwork so it can be trained with a custom loop.
% ---------------------------------------------------------------------
imageSize    = [28 28 1];   % height x width x channels of one input image
numClasses   = 2;           % class count handed to segnetLayers
encoderDepth = 1;           % encoder/decoder depth handed to segnetLayers

seglayers = segnetLayers(imageSize, numClasses, encoderDepth);

% A dlnetwork cannot hold output layers, so strip the classification
% head; the evaluation code below applies its own sigmoid instead.
seglayers = removeLayers(seglayers, {'pixelLabels', 'softmax'});

% Replace the default input layer with one that performs no input
% normalisation.
seglayers = replaceLayer(seglayers, 'inputImage', ...
    imageInputLayer(imageSize, 'Name', 'input_encoder', 'Normalization', 'none'));

% Show the modified layer graph.
figure
plot(seglayers)

segnet = dlnetwork(seglayers);
% ---------------------------------------------------------------------
% Training configuration and custom training loop.
%
% segnet is trained with Adam on mini-batches drawn from XTrain; the
% gradients come from segGradients (defined outside this script). After
% every epoch the loss on the whole test set is reported via ELBOloss.
% ---------------------------------------------------------------------
executionEnvironment = "auto";   % "auto" or "gpu" may select the GPU; any other value -> CPU
numEpochs = 20;
miniBatchSize = 512;
lr = 1e-3;                       % Adam learning rate

% Only full mini-batches are used per epoch; the remainder is skipped
% (but varies between epochs because of the shuffle below).
numIterations = floor(numTrainImages/miniBatchSize);
if numIterations < 1
    % Without this guard the loop below would silently do no training.
    error("train_segnet:tooFewImages", ...
        "Need at least miniBatchSize (%d) training images, found %d.", ...
        miniBatchSize, numTrainImages);
end

iteration = 0;                   % global step counter passed to adamupdate
avgGradients = [];               % Adam moment estimates; [] lets adamupdate initialise them
avgGradientsSquared = [];

% Pick the device once instead of re-checking canUseGPU every iteration.
useGPU = (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu";

% Prepare the test set exactly like a training mini-batch (single
% precision, 'SSCB' dlarray, optional GPU) and do it only once.
dlXTest = dlarray(single(XTest), 'SSCB');
if useGPU
    dlXTest = gpuArray(dlXTest);
end

for epoch = 1:numEpochs
    tic;
    % Reshuffle every epoch so the mini-batches differ between epochs and
    % the images left over by floor() above are not always the same ones.
    shuffledIdx = randperm(numTrainImages);
    for i = 1:numIterations
        iteration = iteration + 1;
        idx = shuffledIdx((i-1)*miniBatchSize+1:i*miniBatchSize);
        XBatch = XTrain(:,:,:,idx);
        XBatch = dlarray(single(XBatch), 'SSCB');
        if useGPU
            XBatch = gpuArray(XBatch);
        end
        grad = dlfeval(@segGradients, segnet, XBatch);
        [segnet.Learnables, avgGradients, avgGradientsSquared] = ...
            adamupdate(segnet.Learnables, ...
            grad, avgGradients, avgGradientsSquared, iteration, lr);
    end
    elapsedTime = toc;

    % NOTE(review): forward (not predict) is used for evaluation. segnet.State
    % is never updated in this loop (segGradients returns only gradients),
    % so predict would rely on never-updated normalisation statistics if the
    % network has any - confirm this is intended.
    xPred = sigmoid(forward(segnet, dlXTest));
    % NOTE(review): the trailing 1, 0 are passed straight to ELBOloss; their
    % meaning depends on that helper (they may be placeholder latent
    % statistics) - verify they contribute what is intended to the loss.
    elbo = ELBOloss(dlXTest, xPred, 1, 0);
    disp("Epoch : "+epoch+" Test ELBO loss = "+gather(extractdata(elbo))+...
        ". Time taken for epoch = "+ elapsedTime + "s")
end
% Show test-set reconstructions (helper defined outside this script),
% then write the trained network to segnet.mat in the current folder.
visualizeSegReconstruction(XTest,YTest,segnet)
save('segnet.mat', 'segnet');