Skip to content

Commit

Permalink
Merge pull request #267 from JuliaStats/ast/seeding_cleanup
Browse files Browse the repository at this point in the history
Cleanup kmeans() seeding
  • Loading branch information
alyst authored Dec 19, 2023
2 parents 377678d + 0167254 commit 0bade2b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 29 deletions.
19 changes: 4 additions & 15 deletions src/seeding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,8 @@ function initseeds!(iseeds::AbstractVector{<:Integer}, alg::KmppAlg,
iseeds[j] = p

# update mincosts
c = view(X, :, p)
colwise!(metric, tmpcosts, X, view(X, :, p))
updatemin!(mincosts, tmpcosts)
mincosts .= min.(mincosts, tmpcosts)
mincosts[p] = 0
end
end
Expand Down Expand Up @@ -211,7 +210,7 @@ function initseeds_by_costs!(iseeds::AbstractVector{<:Integer}, alg::KmppAlg,
iseeds[j] = p

# update mincosts
updatemin!(mincosts, view(costs, :, p))
mincosts .= min.(mincosts, view(costs, :, p))
mincosts[p] = 0
end
end
Expand Down Expand Up @@ -240,21 +239,11 @@ function initseeds_by_costs!(iseeds::AbstractVector{<:Integer}, alg::KmCentralit
k = length(iseeds)
check_seeding_args(n, k)

# compute score for each item
coefs = vec(sum(costs, dims=2))
for i = 1:n
@inbounds coefs[i] = inv(coefs[i])
end

# scores[j] = \sum_i costs[i,j] / (\sum_{j'} costs[i,j'])
#           = \sum_i costs[i,j] * coefs[i]
scores = costs'coefs
scores = costs'vec(mapslices(invsum, costs, dims=2))

# lower score indicates better seeds
sp = sortperm(scores)
for i = 1:k
@inbounds iseeds[i] = sp[i]
end
copyto!(iseeds, 1, sortperm(scores), 1, k)
return iseeds
end

Expand Down
14 changes: 0 additions & 14 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,20 +64,6 @@ display_level(s::Symbol) = get(DisplayLevels, s) do
throw(ArgumentError("Invalid option display=:$s ($(join(valid_vals, ", ", ", or ")) expected)"))
end

##### update minimum value

"""
    updatemin!(r::AbstractArray, x::AbstractArray)

Update `r` in place with the element-wise minimum of `r` and `x`, i.e. set
`r[i] = x[i]` wherever `x[i] < r[i]`, and return `r`.

The two arrays are paired in iteration order, so they may have different shapes
or index offsets as long as they hold the same number of elements. Because the
comparison is a strict `<`, an element of `x` for which the comparison is false
(for example `NaN`) leaves the corresponding element of `r` unchanged.

# Throws
- `DimensionMismatch`: if `length(x) != length(r)`.
"""
function updatemin!(r::AbstractArray, x::AbstractArray)
    length(x) == length(r) ||
        throw(DimensionMismatch("Inconsistent array lengths."))
    # Use each array's native indices rather than assuming `1:n`: that keeps
    # `@inbounds` valid for arrays whose indices do not start at 1.
    @inbounds for (i, j) in zip(eachindex(r), eachindex(x))
        xj = x[j]
        if xj < r[i]
            r[i] = xj
        end
    end
    return r
end

function check_assignments(assignments::AbstractVector{<:Integer}, nclusters::Union{Integer, Nothing})
nclu = nclusters === nothing ? maximum(assignments) : nclusters
for (j, c) in enumerate(assignments)
Expand Down

0 comments on commit 0bade2b

Please sign in to comment.