# Solutions ch. 7 - Nearest neighbours {#solutions-nearest-neighbours}
Solutions to exercises of chapter \@ref(nearest-neighbours).
## Exercise 1
Load libraries
```{r echo=T}
library(caret)
library(RColorBrewer)
library(doMC)
library(corrplot)
```
Prepare for parallel processing
```{r echo=T}
registerDoMC(detectCores())
```
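Load the data. The file provides the objects `variety` (the class labels) and `morphometrics` (the predictors) used below.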
```{r echo=T}
load("data/wheat_seeds/wheat_seeds.Rda")
```
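Partition the data into a training set (70%) and a test set (30%). Because `variety` is a factor, `createDataPartition` samples within each class, so the class balance is similar in both sets.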
```{r echo=T}
set.seed(42)
trainIndex <- createDataPartition(y=variety, times=1, p=0.7, list=FALSE)
varietyTrain <- variety[trainIndex]
morphTrain <- morphometrics[trainIndex,]
varietyTest <- variety[-trainIndex]
morphTest <- morphometrics[-trainIndex,]
summary(varietyTrain)
summary(varietyTest)
```
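For a factor outcome, `createDataPartition` samples within each class. The hypothetical helper `stratified_split()` below is a minimal base-R sketch of that idea, not caret's implementation (caret's rounding of per-class sample sizes may differ). It uses the built-in `iris` data (three species, 50 observations each) as a stand-in for the wheat seeds.

```r
# Hypothetical helper: stratified sampling sketch (not caret's code).
# For each class of y, draw round(p * n_class) row indices at random.
stratified_split <- function(y, p = 0.7) {
  by_class <- split(seq_along(y), y)  # row indices grouped by class
  idx <- lapply(by_class, function(i) i[sample.int(length(i), round(p * length(i)))])
  sort(unname(unlist(idx)))
}

set.seed(42)
train_idx <- stratified_split(iris$Species, p = 0.7)
table(iris$Species[train_idx])   # 35 of each species
table(iris$Species[-train_idx])  # 15 of each species
```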
Data check: zero and near-zero variance predictors
```{r echo=T}
nzv <- nearZeroVar(morphTrain, saveMetrics=TRUE)
nzv
```
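With `saveMetrics = TRUE`, `nearZeroVar` reports for each predictor the ratio of the frequency of the most common value to that of the second most common value (`freqRatio`) and the percentage of distinct values (`percentUnique`). With caret's default cutoffs (`freqCut = 95/5`, `uniqueCut = 10`), a predictor is flagged when it has a single value, or when its frequency ratio is above `freqCut` and its percentage of distinct values is below `uniqueCut`. The hypothetical `nzv_metrics()` below is a simplified base-R sketch of these two metrics; edge cases (such as how constant predictors are scored) may differ from caret.

```r
# Hypothetical helper: simplified version of the nearZeroVar metrics.
nzv_metrics <- function(x, freqCut = 95/5, uniqueCut = 10) {
  counts <- sort(table(x), decreasing = TRUE)
  zeroVar <- length(counts) == 1
  # Frequency of the most common value relative to the second most common
  freqRatio <- if (zeroVar) NA_real_ else counts[[1]] / counts[[2]]
  # Percentage of distinct values among all samples
  percentUnique <- 100 * length(counts) / length(x)
  nzv <- zeroVar || (freqRatio > freqCut && percentUnique < uniqueCut)
  data.frame(freqRatio, percentUnique, zeroVar, nzv)
}

x_rare <- c(rep(0, 99), 1)  # one rare value: ratio 99, 2% distinct
nzv_metrics(x_rare)         # flagged as near-zero variance
do.call(rbind, lapply(iris[, 1:4], nzv_metrics))  # no iris predictor is flagged
```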
Data check: are all predictors on the same scale?
```{r echo=T}
summary(morphTrain)
```
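kNN relies on distances between observations, so a predictor with a large range would dominate the Euclidean distance if left unscaled. The `"center"` and `"scale"` pre-processing steps used later standardize each predictor to mean 0 and standard deviation 1, with the statistics estimated on the training data and then applied unchanged to new data. A minimal base-R sketch of that idea, using `iris` as a stand-in data set:

```r
set.seed(42)
train_idx <- sample.int(nrow(iris), 105)
x_train <- as.matrix(iris[train_idx, 1:4])
x_test  <- as.matrix(iris[-train_idx, 1:4])

# Estimate centring and scaling parameters on the training data only
mu    <- colMeans(x_train)
sigma <- apply(x_train, 2, sd)

# Apply the same parameters to both sets
x_train_s <- scale(x_train, center = mu, scale = sigma)
x_test_s  <- scale(x_test,  center = mu, scale = sigma)

round(colMeans(x_train_s), 10)  # all 0 for the training set
apply(x_train_s, 2, sd)         # all 1 for the training set
```

The test set columns will only be approximately centred and scaled, which is expected: they are transformed with the training-set statistics.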
```{r wheatBoxplots, fig.cap='Boxplots of the 7 geometric parameters in the wheat data set.', out.width='75%', fig.asp=1, fig.align='center', echo=T}
featurePlot(x = morphTrain,
            y = varietyTrain,
            plot = "box",
            ## Pass in options to bwplot()
            scales = list(y = list(relation="free"),
                          x = list(rot = 90)),
            layout = c(3,3))
```
Data check: pairwise correlations between predictors
```{r wheatCorrelogram, fig.cap='Correlogram of the wheat seed data set.', out.width='75%', fig.asp=1, fig.align='center', echo=T}
corMat <- cor(morphTrain)
corrplot(corMat, order="hclust", tl.cex=1)
```
```{r echo=T}
highCorr <- findCorrelation(corMat, cutoff=0.75)
length(highCorr)
names(morphTrain)[highCorr]
```
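`findCorrelation` considers the absolute pairwise correlations; when two variables are highly correlated, it removes the one with the larger mean absolute correlation. The hypothetical `simple_find_correlation()` below is a simplified greedy sketch in that spirit; it is not caret's code and may not select the same columns in every case. On `iris` (a stand-in data set), the two petal measurements are strongly correlated with each other and with `Sepal.Length`.

```r
# Hypothetical helper: greedy correlation filter sketch (not caret's code).
simple_find_correlation <- function(corMat, cutoff = 0.75) {
  a <- abs(corMat)
  diag(a) <- 0                      # ignore self-correlation
  removed <- integer(0)
  repeat {
    keep <- setdiff(seq_len(ncol(a)), removed)
    sub <- a[keep, keep, drop = FALSE]
    if (max(sub) <= cutoff) break   # no remaining pair above the cutoff
    # Most strongly correlated remaining pair (row/column positions in 'sub')
    pair <- which(sub == max(sub), arr.ind = TRUE)[1, ]
    # Mean absolute correlation of each variable with those still kept
    # (the zero diagonal is included, which does not change which is larger)
    m <- rowMeans(sub)
    drop <- if (m[pair[1]] >= m[pair[2]]) pair[1] else pair[2]
    removed <- c(removed, keep[drop])
  }
  removed
}

iris_cor  <- cor(iris[, 1:4])
iris_high <- simple_find_correlation(iris_cor, cutoff = 0.75)
names(iris)[iris_high]  # "Petal.Length" "Petal.Width"
```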
Data check: skewness
```{r wheatDensityPlots, fig.cap='Density plots of the 7 geometric parameters in the wheat data set.', out.width='75%', fig.asp=1, fig.align='center', echo=T}
featurePlot(x = morphTrain,
            y = varietyTrain,
            plot = "density",
            ## Pass in options to densityplot() to
            ## make it prettier
            scales = list(x = list(relation="free"),
                          y = list(relation="free")),
            adjust = 1.5,
            pch = "|",
            layout = c(3, 3),
            auto.key = list(columns = 3))
```
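The density plots give a visual impression of skewness. A numeric summary is the sample skewness; the hypothetical `sample_skewness()` below is a base-R sketch of the simple moment estimator $g_1 = m_3 / m_2^{3/2}$, where $m_r$ is the $r$-th central moment (packages such as e1071 also offer variants with small-sample corrections). Values near 0 suggest a roughly symmetric distribution, positive values a long right tail and negative values a long left tail.

```r
# Hypothetical helper: moment estimator of skewness, g1 = m3 / m2^(3/2)
sample_skewness <- function(x) {
  d <- x - mean(x)
  m2 <- mean(d^2)  # second central moment
  m3 <- mean(d^3)  # third central moment
  m3 / m2^(3 / 2)
}

sample_skewness(c(-1, 0, 1))     # symmetric: 0
sample_skewness(c(1, 1, 1, 10))  # long right tail: positive
sapply(iris[, 1:4], sample_skewness)
```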
Create a 'grid' of candidate values of _k_ (the odd numbers from 1 to 49) for evaluation:
```{r echo=T}
tuneParam <- data.frame(k=seq(1,50,2))
```
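_k_ is the number of nearest training observations that vote on the class of a new observation. The hypothetical `knn_predict()` below is a minimal base-R sketch of kNN classification with Euclidean distance and a majority vote; it is not the implementation behind caret's `method="knn"`, and it breaks ties by taking the first class in factor-level order. It uses `iris` as a stand-in, standardized with training-set statistics.

```r
# Hypothetical helper: kNN classification by majority vote (sketch).
knn_predict <- function(x_train, y_train, x_new, k = 5) {
  x_train <- as.matrix(x_train)
  x_new <- as.matrix(x_new)
  pred <- apply(x_new, 1, function(row) {
    # Euclidean distance from this new point to every training point
    d <- sqrt(colSums((t(x_train) - row)^2))
    nn <- order(d)[seq_len(k)]      # indices of the k nearest neighbours
    votes <- table(y_train[nn])     # class counts among the neighbours
    names(votes)[which.max(votes)]  # first class with the most votes
  })
  factor(unname(pred), levels = levels(y_train))
}

set.seed(42)
train_idx <- sample.int(nrow(iris), 105)
x_train <- iris[train_idx, 1:4]
mu <- colMeans(x_train)
sigma <- apply(x_train, 2, sd)
x_train_s <- scale(x_train, center = mu, scale = sigma)
x_test_s  <- scale(iris[-train_idx, 1:4], center = mu, scale = sigma)

pred <- knn_predict(x_train_s, iris$Species[train_idx], x_test_s, k = 5)
mean(pred == iris$Species[-train_idx])  # test-set accuracy
```

Small _k_ follows the training data closely (high variance), while large _k_ averages over wider neighbourhoods (higher bias), which is why the solution evaluates a range of values by cross-validation.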
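Generate a list of seeds for reproducibility (optional). With 10-fold cross-validation repeated 10 times there are 100 resamples, so `trainControl` expects a list of 101 elements: each of the first 100 holds one seed per candidate value of _k_ (here `length(tuneParam$k)`), and the last holds a single seed for fitting the final model.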
```{r echo=T}
set.seed(42)
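# 10 folds x 10 repeats = 100 resamples, plus one element for the final model
seeds <- vector(mode = "list", length = 101)
# one seed per candidate value of k in each resample
for(i in 1:100) seeds[[i]] <- sample.int(1000, length(tuneParam$k))
# a single seed for the final model
seeds[[101]] <- sample.int(1000,1)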
```
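The same structure can be built by a small hypothetical helper, `make_seeds()`, that derives the list length from the resampling settings instead of hard-coding 101. This is a sketch; `number`, `repeats` and `n_tune` must match the values actually passed to `trainControl` and `tuneGrid`.

```r
# Hypothetical helper: build a seeds list for trainControl(method = "repeatedcv")
make_seeds <- function(number, repeats, n_tune, max_seed = 1000) {
  B <- number * repeats                      # total number of resamples
  seeds <- vector(mode = "list", length = B + 1)
  for (i in seq_len(B)) seeds[[i]] <- sample.int(max_seed, n_tune)  # one per model
  seeds[[B + 1]] <- sample.int(max_seed, 1)  # final model
  seeds
}

set.seed(42)
seeds_check <- make_seeds(number = 10, repeats = 10, n_tune = 25)
length(seeds_check)           # 101
unique(lengths(seeds_check))  # 25 1
```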
<!--
Define a pre-processor (named transformations) and transform morphTrain
```{r echo=T}
transformations <- preProcess(morphTrain,
method=c("center", "scale", "corr"),
cutoff=0.75)
morphTrainT <- predict(transformations, morphTrain)
```
-->
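Set training parameters. In the example in chapter \@ref(nearest-neighbours), pre-processing was performed outside the cross-validation process to save time for the purposes of the demonstration. Here we have a relatively small data set, so we can afford to do pre-processing within each iteration of the cross-validation process. This is the more rigorous approach: the centring, scaling and correlation-filtering parameters are estimated only from the training folds, so no information from the held-out fold leaks into the accuracy estimate. We specify the option `preProcOptions=list(cutoff=0.75)` to set the pairwise correlation coefficient cutoff used by the `"corr"` pre-processing method.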
```{r echo=T}
train_ctrl <- trainControl(method="repeatedcv",
                           number = 10,
                           repeats = 10,
                           preProcOptions=list(cutoff=0.75),
                           seeds = seeds)
```
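Doing pre-processing within each cross-validation iteration can be sketched by hand: in each fold, the centring and scaling parameters are re-estimated from the other nine folds and then applied to the held-out fold, so the held-out observations never influence the transformation. This sketch covers only centring and scaling (not the `"corr"` filter), uses a single set of random, unstratified folds rather than 10 repeats, uses `iris` as a stand-in, and includes a minimal majority-vote kNN (`knn_vote()`, hypothetical) in place of caret's model.

```r
# Minimal majority-vote kNN (hypothetical helper, Euclidean distance)
knn_vote <- function(x_train, y_train, x_new, k) {
  pred <- apply(x_new, 1, function(row) {
    d <- colSums((t(x_train) - row)^2)  # squared distances suffice for ranking
    votes <- table(y_train[order(d)[seq_len(k)]])
    names(votes)[which.max(votes)]
  })
  factor(unname(pred), levels = levels(y_train))
}

x <- as.matrix(iris[, 1:4])
y <- iris$Species
set.seed(42)
folds <- sample(rep(1:10, length.out = nrow(x)))  # random fold labels 1..10

fold_acc <- sapply(1:10, function(f) {
  in_fold <- folds == f
  # Pre-processing parameters from the training folds only
  mu    <- colMeans(x[!in_fold, ])
  sigma <- apply(x[!in_fold, ], 2, sd)
  x_tr  <- scale(x[!in_fold, ], center = mu, scale = sigma)
  x_ho  <- scale(x[in_fold, , drop = FALSE], center = mu, scale = sigma)
  pred  <- knn_vote(x_tr, y[!in_fold], x_ho, k = 5)
  mean(pred == y[in_fold])
})
mean(fold_acc)  # cross-validated accuracy estimate
```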
Run training
```{r echo=T}
knnFit <- train(morphTrain, varietyTrain,
                method="knn",
                preProcess = c("center", "scale", "corr"),
                tuneGrid=tuneParam,
                trControl=train_ctrl)
knnFit
```
Plot cross-validation accuracy as a function of _k_
```{r cvAccuracyMorphTrain, fig.cap='Accuracy (repeated cross-validation) as a function of neighbourhood size for the wheat seeds data set.', out.width='100%', fig.asp=0.6, fig.align='center', echo=T}
plot(knnFit)
```
Predict the class (wheat variety) of the observations in the test set.
```{r echo=T}
test_pred <- predict(knnFit, morphTest)
confusionMatrix(test_pred, varietyTest)
```
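`confusionMatrix` reports, among other statistics, the overall accuracy and Cohen's kappa. Both can be computed from the cross-tabulation of predicted against reference classes: accuracy $p_o$ is the proportion of observations on the diagonal, and kappa compares it with the agreement $p_e$ expected by chance from the row and column totals, $\kappa = (p_o - p_e)/(1 - p_e)$. A base-R sketch with a hypothetical `accuracy_kappa()` helper, applied to a small made-up pair of label vectors:

```r
# Hypothetical helper: accuracy and Cohen's kappa from predicted/reference labels.
# pred and ref must share the same factor levels so the table is square.
accuracy_kappa <- function(pred, ref) {
  tab <- table(Prediction = pred, Reference = ref)
  n   <- sum(tab)
  p_o <- sum(diag(tab)) / n                      # observed agreement (accuracy)
  p_e <- sum(rowSums(tab) * colSums(tab)) / n^2  # agreement expected by chance
  c(Accuracy = p_o, Kappa = (p_o - p_e) / (1 - p_e))
}

ref  <- factor(c("a", "a", "a", "b", "b", "b"))
pred <- factor(c("a", "a", "b", "b", "b", "a"), levels = levels(ref))
accuracy_kappa(pred, ref)  # Accuracy 2/3, Kappa 1/3
```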