-
Notifications
You must be signed in to change notification settings - Fork 18
/
mkkmeans_train.R
executable file
·45 lines (40 loc) · 1.55 KB
/
mkkmeans_train.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
library(Rmosek)
mkkmeans_train <- function(Km, parameters) {
state <- list()
state$time <- system.time({
P <- dim(Km)[3]
theta <- rep(1 / P, P)
K_theta <- matrix(0, nrow(Km), ncol(Km))
for (m in 1:P) {
K_theta <- K_theta + theta[m]^2 * Km[,,m]
}
objective <- rep(0, parameters$iteration_count)
for (iter in 1:parameters$iteration_count) {
print(sprintf("running iteration %d...", iter))
H <- eigen(K_theta, symmetric = TRUE)$vectors[, 1:parameters$cluster_count]
problem <- list()
problem$sense <- "min"
problem$c <- rep(0, P)
problem$A <- Matrix(1, nrow = 1, ncol = P, sparse = TRUE)
problem$bc <- rbind(blc = 1, buc = 1)
problem$bx <- rbind(blx = rep(0, P), bux = rep(1, P))
problem$qobj <- list(i = 1:P, j = 1:P, v = sapply(1:P, function(m) {sum(diag(Km[,,m])) - sum(diag(t(H) %*% Km[,,m] %*% H))}))
opts <- list()
opts$verbose <- 0
result <- mosek(problem, opts)
theta <- result$sol$itr$xx
K_theta <- matrix(0, nrow(Km), ncol(Km))
for (m in 1:P) {
K_theta <- K_theta + theta[m]^2 * Km[,,m]
}
objective[iter] <- sum(diag(t(H) %*% K_theta %*% H)) - sum(diag(K_theta))
}
H_normalized <- H / matrix(sqrt(rowSums(H^2, 2)), nrow(H), parameters$cluster_count, byrow = FALSE)
set.seed(NULL)
state$clustering <- kmeans(H_normalized, centers = parameters$cluster_count, iter.max = 1000, nstart = 10)$cluster
state$objective <- objective
state$parameters <- parameters
state$theta <- theta
})
return(state)
}