Skip to content

Commit d9e06e2

Browse files
author
Andrey Oskin
committed
Elkan and refactoring
1 parent a44f676 commit d9e06e2

File tree

9 files changed

+546
-109
lines changed

9 files changed

+546
-109
lines changed

src/ParallelKMeans.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@ import Distances
77

88
include("seeding.jl")
99
include("kmeans.jl")
10-
include("lloyd.jl")
1110
include("light_elkan.jl")
11+
include("lloyd.jl")
1212
include("hamerly.jl")
13+
include("elkan.jl")
1314
include("mlj_interface.jl")
1415

1516
export kmeans
16-
export Lloyd, LightElkan, Hamerly
17+
export Lloyd, LightElkan, Hamerly, Elkan
1718

1819
end # module

src/elkan.jl

Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
"""
2+
    Elkan()

Elkan algorithm implementation, based on "Charles Elkan. 2003.
Using the triangle inequality to accelerate k-means.
In Proceedings of the Twentieth International Conference on
Machine Learning (ICML'03). AAAI Press, 147–153."

This algorithm provides much faster convergence than the Lloyd algorithm, especially
for high-dimensional data.
It can be used directly with the `kmeans` function:

```julia
X = rand(30, 100_000)   # 100_000 random points in 30 dimensions

kmeans(Elkan(), X, 3)   # 3 clusters, Elkan algorithm
```
18+
"""
19+
struct Elkan <: AbstractKMeansAlg
    # No fields: the type carries no state and is only used to select the
    # `::Elkan` methods through dispatch.
end
20+
21+
"""
    kmeans!(alg::Elkan, containers, X, k; n_threads, k_init, max_iters, tol, verbose, init)

Run the Elkan variant of k-means on `X` with `k` clusters, using the preallocated
`containers` as working storage, and return a `KmeansResult`.

# Keywords
- `n_threads`: number of chunks handed to `@parallelize` (default `Threads.nthreads()`).
- `k_init`: seeding method forwarded to `smart_init` when `init` is not supplied.
- `max_iters`: upper limit on the number of iterations of the main loop.
- `tol`: relative tolerance on the change of the objective estimate `J`.
- `verbose`: print per-iteration progress when `true`.
- `init`: optional initial centroids; a deep copy is taken, so the argument itself
  is not mutated.
"""
function kmeans!(alg::Elkan, containers, X, k;
                 n_threads = Threads.nthreads(),
                 k_init = "k-means++", max_iters = 300,
                 tol = 1e-6, verbose = false, init = nothing)
    _, ncol = size(X)
    centroids = if init === nothing
        smart_init(X, k, n_threads, init=k_init).centroids
    else
        deepcopy(init)
    end

    # Centroid-to-centroid distances must exist before the first assignment,
    # because `chunk_initialize` reads them for its triangle-inequality test.
    update_containers(alg, containers, centroids, n_threads)
    @parallelize n_threads ncol chunk_initialize(alg, containers, centroids, X)

    converged = false
    niters = 1
    J_previous = 0.0

    # Update centroids & labels with closest members until convergence
    while niters <= max_iters
        # Core iteration
        @parallelize n_threads ncol chunk_update_centroids(alg, containers, centroids, X)

        # Collect distributed containers (such as centroids_new, centroids_cnt)
        # in paper it is step 4
        collect_containers(alg, containers, n_threads)

        # Objective estimate: sum of the per-point upper bounds
        J = sum(containers.ub)

        # auxiliary calculation, in paper it's d(c, m(c))
        calculate_centroids_movement(alg, containers, centroids)

        # lower and upper bounds update, in paper it's steps 5 and 6
        @parallelize n_threads ncol chunk_update_bounds(alg, containers, centroids)

        # Step 7, final assignment of new centroids
        centroids .= containers.centroids_new[end]

        if verbose
            # Show progress and terminate if J stopped decreasing.
            println("Iteration $niters: Jclust = $J")
        end

        # Check for convergence
        if niters > 1 && abs(J - J_previous) < tol * J
            converged = true
            break
        end

        # Step 1 in original paper, calculation of distance d(c, c')
        update_containers(alg, containers, centroids, n_threads)
        J_previous = J
        niters += 1
    end
    # NOTE: when the loop runs out of iterations without converging, `niters`
    # equals `max_iters + 1` here, and that value is what gets reported below.

    @parallelize n_threads ncol sum_of_squares(containers, X, containers.labels, centroids)
    totalcost = sum(containers.sum_of_squares)

    # Terminate algorithm with the assumption that K-means has converged
    if verbose && converged
        println("Successfully terminated with convergence.")
    end

    # TODO empty placeholder vectors should be calculated
    # TODO Float64 type definitions is too restrictive, should be relaxed
    # especially during GPU related development
    return KmeansResult(centroids, containers.labels, Float64[], Int[], Float64[], totalcost, niters, converged)
end
85+
86+
"""
    create_containers(::Elkan, k, nrow, ncol, n_threads)

Allocate the working storage used by the Elkan algorithm for `k` clusters and a
data matrix with `nrow` rows and `ncol` columns, and return it as a named tuple.

# Returned fields
- `centroids_new`, `centroids_cnt`: `n_threads + 1` accumulators each (per-cluster
  sums and integer counts); the last slot is the one `collect_containers`
  writes the aggregated result into.
- `labels`: cluster index of every point, initialised to zero.
- `centroids_dist`: `k × k` matrix of centroid-to-centroid distances (uninitialised).
- `lb`: `k × ncol` lower bounds (uninitialised).
- `ub`: `ncol` upper bounds (uninitialised).
- `stale`: `ncol` flags, all `true` initially.
- `p`: `k` centroid movements (uninitialised).
- `sum_of_squares`: `n_threads` partial sums (uninitialised).
"""
function create_containers(::Elkan, k, nrow, ncol, n_threads)
    lng = n_threads + 1

    # Counts are integers: allocate them as `Int` directly instead of creating a
    # Float64 vector that is then implicitly converted on assignment.
    centroids_new = [zeros(nrow, k) for _ in 1:lng]
    centroids_cnt = [zeros(Int, k) for _ in 1:lng]

    centroids_dist = Matrix{Float64}(undef, k, k)

    # lower bounds
    lb = Matrix{Float64}(undef, k, ncol)

    # upper bounds
    ub = Vector{Float64}(undef, ncol)

    # r(x) in original paper, shows whether point distance should be updated.
    # A plain Vector{Bool} (one byte per flag) is kept on purpose.
    stale = ones(Bool, ncol)

    # distance that centroid moved
    p = Vector{Float64}(undef, k)

    labels = zeros(Int, ncol)

    # total_sum_calculation
    sum_of_squares = Vector{Float64}(undef, n_threads)

    return (
        centroids_new = centroids_new,
        centroids_cnt = centroids_cnt,
        labels = labels,
        centroids_dist = centroids_dist,
        lb = lb,
        ub = ub,
        stale = stale,
        p = p,
        sum_of_squares = sum_of_squares
    )
end
127+
128+
"""
    chunk_initialize(::Elkan, containers, centroids, X, r, idx)

Perform the initial assignment for the points (columns of `X`) in range `r`:
pick a label for each point, fill its lower bounds `lb[:, i]` and upper bound
`ub[i]`, and accumulate the point into the per-chunk sums
`containers.centroids_new[idx]` and counts `containers.centroids_cnt[idx]`.

`containers.centroids_dist` must already be filled (see `update_containers`).
`distance(A, B, i, j)` is defined elsewhere; presumably it compares column `i` of
`A` with column `j` of `B` — confirm against its definition.
"""
function chunk_initialize(::Elkan, containers, centroids, X, r, idx)
    ub = containers.ub
    lb = containers.lb
    centroids_dist = containers.centroids_dist
    labels = containers.labels
    centroids_new = containers.centroids_new[idx]
    centroids_cnt = containers.centroids_cnt[idx]

    # The accumulators below are only ever added to, so they must start from
    # zero. Resetting them here keeps the result correct when `containers` is
    # reused for another run; on freshly created containers this is a no-op.
    fill!(centroids_new, zero(eltype(centroids_new)))
    fill!(centroids_cnt, zero(eltype(centroids_cnt)))

    @inbounds for i in r
        min_dist = distance(X, centroids, i, 1)
        label = 1
        lb[label, i] = min_dist
        for j in 2:size(centroids, 2)
            # triangular inequality: centroid j is too far from the current best
            # one to be closer to the point, so skip the distance evaluation and
            # use the current best distance as its lower bound
            if centroids_dist[j, label] > min_dist
                lb[j, i] = min_dist
            else
                dist = distance(X, centroids, i, j)
                if dist < min_dist
                    label = j
                    min_dist = dist
                end
                lb[j, i] = dist
            end
        end
        ub[i] = min_dist
        labels[i] = label
        centroids_cnt[label] += 1
        for j in axes(X, 1)
            centroids_new[j, label] += X[j, i]
        end
    end
    return nothing
end
159+
160+
"""
    update_containers(::Elkan, containers, centroids, n_threads)

Refresh `containers.centroids_dist` from the current `centroids` and return it.

Off-diagonal entry `[i, j]` receives `distance(centroids, centroids, i, j)`; the
diagonal entry `[j, j]` receives the smallest distance from centroid `j` to any
other centroid. Every entry is finally multiplied by `0.25`. `n_threads` is not
used by this method.
"""
function update_containers(::Elkan, containers, centroids, n_threads)
    dmat = containers.centroids_dist
    nclusters = size(dmat, 1)

    @inbounds for col in axes(dmat, 2)
        nearest = Inf

        # Pairs below the diagonal are evaluated here and mirrored above it.
        for row in (col + 1):nclusters
            d = distance(centroids, centroids, row, col)
            dmat[row, col] = d
            dmat[col, row] = d
            if !(nearest < d)
                nearest = d
            end
        end

        # Pairs with smaller index were already stored by earlier columns.
        for other in 1:(col - 1)
            known = dmat[col, other]
            if !(nearest < known)
                nearest = known
            end
        end

        dmat[col, col] = nearest
    end

    # TODO: oh, one should be careful here. inequality holds for euclidean metrics
    # not square euclidean. So, for Lp norm it should be something like
    # centroids_dist = 0.5^p. Should check one more time original paper
    dmat .*= 0.25

    return dmat
end
186+
187+
"""
    chunk_update_centroids(::Elkan, containers, centroids, X, r, idx)

Re-examine the assignment of the points (columns of `X`) in range `r`, using the
stored bounds to skip distance evaluations where possible. When a point changes
label, the per-chunk accumulators `containers.centroids_new[idx]` and
`containers.centroids_cnt[idx]` are adjusted incrementally (the point is removed
from the old cluster and added to the new one).
"""
function chunk_update_centroids(::Elkan, containers, centroids, X, r, idx)
    # unpack
    ub = containers.ub
    lb = containers.lb
    centroids_dist = containers.centroids_dist
    labels = containers.labels
    stale = containers.stale
    centroids_new = containers.centroids_new[idx]
    centroids_cnt = containers.centroids_cnt[idx]

    @inbounds for i in r
        label_old = labels[i]
        label = label_old
        min_dist = ub[i]
        # tighten the loop, exclude points that are very close to their center
        min_dist <= centroids_dist[label, label] && continue
        for j in axes(centroids, 2)
            # tighten the loop once more, exclude far away centers
            j == label && continue
            min_dist <= lb[j, i] && continue
            min_dist <= centroids_dist[j, label] && continue

            # one calculation per iteration is enough: replace the upper bound
            # by the exact distance to the currently assigned centroid
            if stale[i]
                min_dist = distance(X, centroids, i, label)
                lb[label, i] = min_dist
                ub[i] = min_dist
                stale[i] = false
            end

            if min_dist > lb[j, i] || min_dist > centroids_dist[j, label]
                dist = distance(X, centroids, i, j)
                lb[j, i] = dist
                if dist < min_dist
                    min_dist = dist
                    label = j
                end
            end
        end

        if label != label_old
            labels[i] = label
            # A label can only change after an exact distance was computed, so
            # `min_dist` is the exact distance to the new centroid. Store it;
            # otherwise `ub[i]` would keep the (larger) distance to the old one.
            ub[i] = min_dist
            centroids_cnt[label_old] -= 1
            centroids_cnt[label] += 1
            for j in axes(X, 1)
                centroids_new[j, label_old] -= X[j, i]
                centroids_new[j, label] += X[j, i]
            end
        end
    end
    return nothing
end
238+
239+
"""
    collect_containers(alg::Elkan, containers, n_threads)

Combine the per-chunk accumulators into `containers.centroids_new[end]`: the
summed coordinates are divided column-wise by the summed counts. With a single
thread the first accumulator is divided directly and `centroids_cnt[end]` is
left untouched. Returns the aggregated matrix.
"""
function collect_containers(alg::Elkan, containers, n_threads)
    sums = containers.centroids_new
    counts = containers.centroids_cnt
    total = sums[end]

    if n_threads == 1
        @inbounds total .= sums[1] ./ counts[1]'
        return total
    end

    total_cnt = counts[end]
    @inbounds total .= sums[1]
    @inbounds total_cnt .= counts[1]
    @inbounds for t in 2:n_threads
        total .+= sums[t]
        total_cnt .+= counts[t]
    end

    @inbounds total .= total ./ total_cnt'
    return total
end
253+
254+
"""
    calculate_centroids_movement(alg::Elkan, containers, centroids)

Store in `containers.p[c]`, for every column `c` of `centroids`, the value of
`distance` between the current centroid `c` and its replacement in
`containers.centroids_new[end]`.
"""
function calculate_centroids_movement(alg::Elkan, containers, centroids)
    moved = containers.p
    updated = containers.centroids_new[end]

    for c in axes(centroids, 2)
        moved[c] = distance(centroids, updated, c, c)
    end
    return nothing
end
262+
263+
264+
"""
    chunk_update_bounds(alg, containers, centroids, r, idx)

Adjust the bounds of the points in range `r` for the centroid movements stored
in `containers.p`:

- every lower bound `lb[c, i]` becomes `lb + p[c] - 2*sqrt(|lb * p[c]|)` when it
  exceeds `p[c]`, and `0.0` otherwise;
- the upper bound `ub[i]` grows by `p[l] + 2*sqrt(|ub * p[l]|)`, where `l` is the
  point's current label;
- `stale[i]` is set to `true`.

`centroids` is only used for its number of columns; `alg` and `idx` are unused.
"""
function chunk_update_bounds(alg, containers, centroids, r, idx)
    shift = containers.p
    lower = containers.lb
    upper = containers.ub
    needs_refresh = containers.stale
    assignment = containers.labels

    @inbounds for pt in r
        for c in axes(centroids, 2)
            bound = lower[c, pt]
            step = shift[c]
            if bound > step
                lower[c, pt] = bound + step - 2*sqrt(abs(bound*step))
            else
                lower[c, pt] = 0.0
            end
        end
        needs_refresh[pt] = true
        own_step = shift[assignment[pt]]
        upper[pt] += own_step + 2*sqrt(abs(upper[pt]*own_step))
    end
    return nothing
end

0 commit comments

Comments
 (0)