-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_code.txt
37 lines (26 loc) · 1.13 KB
/
main_code.txt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
% Load the worm images into a datastore and attach a class label to each one.
% The labels come from the Status column of the WormData table.
imageFolder = "C:\Users\Kishor\Downloads\deeplearning_course_files\Roundworms\WormImages";
imds = imageDatastore(imageFolder);

% NOTE(review): the file name has no extension; confirm that readtable
% resolves "WormData" to the intended file on the MATLAB path.
groundtruth = readtable("WormData");

% Labels are assigned positionally (row k -> image k), so the table must
% have exactly one row per image. Fail early with a clear message otherwise.
% NOTE(review): this also assumes the table rows are in the same order as
% imds.Files - confirm against the data file.
assert(height(groundtruth) == numel(imds.Files), ...
    "WormData has %d rows but the datastore holds %d images; labels cannot be assigned by position.", ...
    height(groundtruth), numel(imds.Files));
imds.Labels = categorical(groundtruth.Status);
% Preview the first few images side by side. Consecutive imshow calls on the
% same axes replace one another, so each image gets its own tile in a new
% figure. The count is clamped so a datastore with fewer than three images
% does not raise an index error.
numPreview = min(3, numel(imds.Files));
if numPreview > 0
    figure
    tiledlayout(1, numPreview);
    for k = 1:numPreview
        nexttile
        imshow(readimage(imds, k))
    end
end
% Randomly split the labeled images: a 0.6 proportion of each label goes to
% the training set and the remainder to the test set.
[trainImgs, testImgs] = splitEachLabel(imds, 0.6, "randomized");

% Wrap both sets so images are resized to 224-by-224 and grayscale images
% are converted to three channels as they are read.
inputSize = [224 224];
trainds = augmentedImageDatastore(inputSize, trainImgs, ColorPreprocessing="gray2rgb");
testds  = augmentedImageDatastore(inputSize, testImgs,  ColorPreprocessing="gray2rgb");
% Build a transfer-learning layer graph from pretrained GoogLeNet by replacing
% the layers named "loss3-classifier" and "output" with new, untrained ones.
net = googlenet;
lgraph = layerGraph(net);

% Size the new fully connected layer from the label categories instead of
% hard-coding the number of classes.
numClasses = numel(categories(trainImgs.Labels));
newFc = fullyConnectedLayer(numClasses, "Name", "new_fc");
lgraph = replaceLayer(lgraph, "loss3-classifier", newFc);

% Statements are terminated with semicolons so the layer objects and the
% full layer graph are not echoed to the command window.
newOut = classificationLayer("Name", "new_out");
lgraph = replaceLayer(lgraph, "output", newOut);
% Configure SGDM training: initial learning rate 0.001, 30 epochs.
options = trainingOptions("sgdm", ...
    InitialLearnRate=0.001, ...
    MaxEpochs=30);

% Train the modified layer graph on the training datastore. The statement is
% left unsuppressed so the resulting network is echoed, as before.
wormsnet = trainNetwork(trainds, lgraph, options)
% Evaluate the trained network on the held-out test images.
preds = classify(wormsnet, testds);
truetest = testImgs.Labels;

% Fraction of test images whose predicted label equals the true label.
% Left unsuppressed so the value is displayed.
accuracy = mean(preds == truetest)

% Draw the confusion chart in its own figure so the image shown below does
% not replace it.
figure
confusionchart(truetest, preds);

% Indices of the misclassified test images (displayed).
idx = find(preds ~= truetest)
if ~isempty(idx)
    % Show the first misclassified image in a separate figure, titled with
    % its true label.
    figure
    imshow(readimage(testImgs, idx(1)))
    title(truetest(idx(1)))
end