Skip to content

Commit 8049b0d

Browse files
author
Andrey Oskin
committed
Coresets implementation
1 parent e95461b commit 8049b0d

File tree

6 files changed

+241
-83
lines changed

6 files changed

+241
-83
lines changed

src/ParallelKMeans.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,16 @@ import Distances
77

88
const MMI = MLJModelInterface
99

10-
include("seeding.jl")
1110
include("kmeans.jl")
11+
include("seeding.jl")
1212
include("lloyd.jl")
1313
include("hamerly.jl")
1414
include("elkan.jl")
1515
include("yinyang.jl")
1616
include("mlj_interface.jl")
17+
include("coreset.jl")
1718

1819
export kmeans
19-
export Lloyd, Hamerly, Elkan, Yinyang
20+
export Lloyd, Hamerly, Elkan, Yinyang, Coreset
2021

2122
end # module

src/coreset.jl

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
"""
2+
Coreset()
3+
4+
Coreset algorithm implementation, based on "Lucic, Mario & Bachem,
5+
Olivier & Krause, Andreas. (2015). Strong Coresets for Hard and Soft Bregman
6+
Clustering with Applications to Exponential Family Mixtures. "
7+
8+
It can be used directly in `kmeans` function
9+
10+
```julia
11+
X = rand(30, 100_000) # 100_000 random points in 30 dimensions
12+
13+
kmeans(Coreset(), X, 3) # 3 clusters, Coreset algorithm
14+
```
15+
"""
16+
struct Coreset{T <: AbstractKMeansAlg} <: AbstractKMeansAlg
    # Number of points sampled from the input to build the coreset.
    m::Int
    # Algorithm that is run on the sampled (weighted) coreset.
    alg::T
end

# Convenience constructor: `Coreset()` keeps the previous defaults (100 sampled
# points clustered with `Lloyd()`), while either setting can now be overridden
# by keyword, e.g. `Coreset(m = 500)` or `Coreset(alg = Hamerly())`.
Coreset(; m = 100, alg = Lloyd()) = Coreset(m, alg)
22+
23+
"""
    kmeans!(alg::Coreset, containers, X, k, weights; kwargs...)

Cluster the columns of `X` into `k` groups by sampling a weighted coreset of
`alg.m` points and running `alg.alg` on that sample.

`containers` is the named tuple produced by `create_containers(alg, ...)`; it is
mutated in place (labels, sensitivities and per-thread accumulators). `weights`
is either `nothing` or a vector of per-point weights.

# Keywords
- `n_threads`: number of chunks used by `@parallelize`.
- `k_init`: seeding strategy forwarded to `smart_init` when `init` is `nothing`.
- `max_iters`, `tol`, `verbose`: forwarded to the inner `kmeans` call.
- `init`: optional initial centroids; a deep copy is used when provided.

Returns a `KmeansResult` built from the inner run's centers, the labels of all
points of `X` with respect to those centers, and the total cost over `X`.
"""
function kmeans!(alg::Coreset, containers, X, k, weights;
                 n_threads = Threads.nthreads(),
                 k_init = "k-means++", max_iters = 300,
                 tol = eltype(X)(1e-6), verbose = false, init = nothing)
    # NOTE: the default of `tol` must refer to `X`; there is no `design_matrix`
    # in this method's scope, so the old default failed when `tol` was omitted.
    T = eltype(X)
    ncol = size(X, 2)
    centroids = isnothing(init) ? smart_init(X, k, n_threads, init=k_init).centroids : deepcopy(init)

    # Steps 2-4 of the paper's algorithm 3
    # We distribute points over the centers and calculate weights of each cluster
    @parallelize n_threads ncol chunk_fit(alg, containers, centroids, X, weights)

    # after this step, containers.centroids_const holds the per-cluster part
    # of the sensitivity
    collect_containers(alg, containers, n_threads)

    # step 7 of the algorithm 3
    @parallelize n_threads ncol chunk_update_sensitivity(alg, containers)

    # sample from containers.s
    coreset_ids = wsample(1:ncol, containers.s, alg.m)
    coreset = X[:, coreset_ids]
    # create new weights as 1/s[i]
    coreset_weights = one(T) ./ @view containers.s[coreset_ids]

    # run usual kmeans for new set with new weights.
    res = kmeans(alg.alg, coreset, k, coreset_weights, tol = tol, max_iters = max_iters,
                 verbose = verbose, init = centroids, n_threads = n_threads)

    # label every point of the full data set against the coreset solution and
    # accumulate the per-thread cost
    @parallelize n_threads ncol chunk_apply(alg, containers, res.centers, X, weights)

    totalcost = sum(containers.totalcost)

    return KmeansResult(res.centers, containers.labels, T[], Int[], T[], totalcost, res.iterations, res.converged)
end
57+
58+
"""
    create_containers(alg::Coreset, X, k, nrow, ncol, n_threads)

Allocate the working buffers used by the `Coreset` algorithm and return them as
a named tuple. Element type of all floating buffers is `eltype(X)`.
"""
function create_containers(alg::Coreset, X, k, nrow, ncol, n_threads)
    T = eltype(X)

    # One zero-initialised accumulator of length `k` per thread.
    centroids_cnt = Vector{T}[zeros(T, k) for _ in 1:n_threads]
    centroids_dist = Vector{T}[zeros(T, k) for _ in 1:n_threads]

    # Per-point buffers: assigned label and sensitivity.
    labels = Vector{Int}(undef, ncol)
    s = Vector{T}(undef, ncol)

    # Per-thread slots; J relates to $c_\phi$ in the paper.
    J = Vector{T}(undef, n_threads)
    totalcost = Vector{T}(undef, n_threads)

    # Per-cluster additive term of the sensitivity.
    centroids_const = Vector{T}(undef, k)

    alpha = 16 * (log(k) + 2)

    return (
        centroids_cnt = centroids_cnt,
        centroids_dist = centroids_dist,
        s = s,
        labels = labels,
        totalcost = totalcost,
        J = J,
        centroids_const = centroids_const,
        alpha = alpha
    )
end
95+
96+
"""
    chunk_fit(alg::Coreset, containers, centroids, X, weights, r, idx)

For every point index in `r`, find the closest column of `centroids`, store its
label and distance, and accumulate the per-cluster count and distance sums into
the accumulators at position `idx`. The sum of distances over the chunk is
written to `containers.J[idx]`.
"""
function chunk_fit(alg::Coreset, containers, centroids, X, weights, r, idx)
    T = eltype(X)
    cluster_cnt = containers.centroids_cnt[idx]
    cluster_dist = containers.centroids_dist[idx]
    labels = containers.labels
    s = containers.s

    chunk_sum = zero(T)
    for i in r
        # closest center (steps 2-3 of the paper's algorithm 3)
        best = 1
        best_dist = distance(X, centroids, i, 1)
        for j in 2:size(centroids, 2)
            candidate = distance(X, centroids, i, j)
            if candidate < best_dist
                best = j
                best_dist = candidate
            end
        end
        labels[i] = best

        # calculation of the $c_\phi$ (step 4)
        # Note: $d_A(x', B) = min_{b \in B} d_A(x', b)$
        # The treatment of `weights` needs further investigation;
        # for `nothing` weights (default) it works as intended.
        if isnothing(weights)
            cluster_cnt[best] += one(T)
            cluster_dist[best] += best_dist
        else
            cluster_cnt[best] += weights[i]
            cluster_dist[best] += weights[i] * best_dist
        end
        chunk_sum += best_dist

        # store the raw distance for now; the sensitivity is completed later
        s[i] = best_dist
    end

    containers.J[idx] = chunk_sum
end
130+
131+
"""
    collect_containers(::Coreset, containers, n_threads)

Reduce the per-thread accumulators into the first slot and fill
`containers.centroids_const` with the per-cluster term of the sensitivity.
"""
function collect_containers(::Coreset, containers, n_threads)
    # Here we transform formula of the step 6
    # By multiplying both sides of equation on $c_\phi / \alpha$ we obtain
    # $s(x) <- d_A(x, B) + 2 \sum d_A(x, B) / |B_i| + 4 c_\phi |\Chi| / (|B_i| * \alpha)$
    # Taking into account that $c_\phi = 1/|\Chi| \sum d_A(x', B) = J / |\Chi|$ we get
    # $s(x) <- d_A(x, B) + 2 \sum d_A(x, B) / |B_i| + 4 J / \alpha * 1/ |B_i|$

    # Slot 1 doubles as the reduction target and is updated in place.
    cnt = containers.centroids_cnt[1]
    dist_sum = containers.centroids_dist[1]
    total = containers.J[1]

    @inbounds for t in 2:n_threads
        cnt .+= containers.centroids_cnt[t]
        dist_sum .+= containers.centroids_dist[t]
        total += containers.J[t]
    end

    scaled = 4 * total / containers.alpha

    containers.centroids_const .= 2 .* dist_sum ./ cnt .+ scaled ./ cnt

    return nothing
end
158+
159+
"""
    chunk_update_sensitivity(alg::Coreset, containers, r, idx)

Add the per-cluster constant of each point's assigned cluster to its entry in
`containers.s`, for the point indices in `r`.
"""
function chunk_update_sensitivity(alg::Coreset, containers, r, idx)
    s = containers.s
    labels = containers.labels
    per_cluster = containers.centroids_const

    @inbounds for i in r
        cluster = labels[i]
        s[i] = s[i] + per_cluster[cluster]
    end

    return nothing
end
168+
169+
"""
    chunk_apply(alg::Coreset, containers, centroids, X, weights, r, idx)

For every point index in `r`, assign the label of the closest column of
`centroids` and accumulate the (optionally weighted) distance to it. The
accumulated cost of the chunk is written to `containers.totalcost[idx]`.
"""
function chunk_apply(alg::Coreset, containers, centroids, X, weights, r, idx)
    # The per-thread count/distance accumulators are not needed here; only
    # labels and the cost slot are written.
    labels = containers.labels
    T = eltype(X)

    J = zero(T)
    for i in r
        dist = distance(X, centroids, i, 1)
        label = 1
        for j in 2:size(centroids, 2)
            new_dist = distance(X, centroids, i, j)

            # keep the closest center seen so far
            label = new_dist < dist ? j : label
            dist = new_dist < dist ? new_dist : dist
        end
        labels[i] = label
        J += isnothing(weights) ? dist : weights[i] * dist
    end

    containers.totalcost[idx] = J
end

src/hamerly.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@ kmeans(Hamerly(), X, 3) # 3 clusters, Hamerly algorithm
1818
struct Hamerly <: AbstractKMeansAlg end
1919

2020

21-
function kmeans!(alg::Hamerly, containers, X, k;
21+
function kmeans!(alg::Hamerly, containers, X, k, weights;
2222
n_threads = Threads.nthreads(),
2323
k_init = "k-means++", max_iters = 300,
2424
tol = eltype(X)(1e-6), verbose = false, init = nothing)
2525
nrow, ncol = size(X)
2626
centroids = init == nothing ? smart_init(X, k, n_threads, init=k_init).centroids : deepcopy(init)
2727

28-
@parallelize n_threads ncol chunk_initialize(alg, containers, centroids, X)
28+
@parallelize n_threads ncol chunk_initialize(alg, containers, centroids, X, weights)
2929

3030
T = eltype(X)
3131
converged = false
@@ -37,7 +37,7 @@ function kmeans!(alg::Hamerly, containers, X, k;
3737
while niters < max_iters
3838
niters += 1
3939
update_containers(alg, containers, centroids, n_threads)
40-
@parallelize n_threads ncol chunk_update_centroids(alg, containers, centroids, X)
40+
@parallelize n_threads ncol chunk_update_centroids(alg, containers, centroids, X, weights)
4141
collect_containers(alg, containers, n_threads)
4242

4343
J = sum(containers.ub)
@@ -60,7 +60,7 @@ function kmeans!(alg::Hamerly, containers, X, k;
6060
J_previous = J
6161
end
6262

63-
@parallelize n_threads ncol sum_of_squares(containers, X, containers.labels, centroids)
63+
@parallelize n_threads ncol sum_of_squares(containers, X, containers.labels, centroids, weights)
6464
totalcost = sum(containers.sum_of_squares)
6565

6666
# Terminate algorithm with the assumption that K-means has converged
@@ -119,16 +119,16 @@ end
119119
120120
Initial calculation of all bounds and point labeling.
121121
"""
122-
function chunk_initialize(alg::Hamerly, containers, centroids, X, r, idx)
122+
function chunk_initialize(alg::Hamerly, containers, centroids, X, weights, r, idx)
123123
T = eltype(X)
124124
centroids_cnt = containers.centroids_cnt[idx]
125125
centroids_new = containers.centroids_new[idx]
126126

127127
@inbounds for i in r
128128
label = point_all_centers!(containers, centroids, X, i)
129-
centroids_cnt[label] += one(T)
129+
centroids_cnt[label] += isnothing(weights) ? one(T) : weights[i]
130130
for j in axes(X, 1)
131-
centroids_new[j, label] += X[j, i]
131+
centroids_new[j, label] += isnothing(weights) ? X[j, i] : weights[i] * X[j, i]
132132
end
133133
end
134134
end
@@ -159,7 +159,7 @@ Detailed description of this function can be found in the original paper. It ite
159159
all points and tries to skip some calculation using known upper and lower bounds of distances
160160
from point to centers. If it fails to skip, then it falls back to the generic `point_all_centers!` function.
161161
"""
162-
function chunk_update_centroids(alg::Hamerly, containers, centroids, X, r, idx)
162+
function chunk_update_centroids(alg::Hamerly, containers, centroids, X, weights, r, idx)
163163

164164
# unpack containers for easier manipulations
165165
centroids_new = containers.centroids_new[idx]
@@ -183,11 +183,11 @@ function chunk_update_centroids(alg::Hamerly, containers, centroids, X, r, idx)
183183
label_new = point_all_centers!(containers, centroids, X, i)
184184
if label != label_new
185185
labels[i] = label_new
186-
centroids_cnt[label_new] += one(T)
187-
centroids_cnt[label] -= one(T)
186+
centroids_cnt[label_new] += isnothing(weights) ? one(T) : weights[i]
187+
centroids_cnt[label] -= isnothing(weights) ? one(T) : weights[i]
188188
for j in axes(X, 1)
189-
centroids_new[j, label_new] += X[j, i]
190-
centroids_new[j, label] -= X[j, i]
189+
centroids_new[j, label_new] += isnothing(weights) ? X[j, i] : weights[i] * X[j, i]
190+
centroids_new[j, label] -= isnothing(weights) ? X[j, i] : weights[i] * X[j, i]
191191
end
192192
end
193193
end

src/kmeans.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -109,13 +109,11 @@ design matrix(x), centroids (centre), and the number of desired groups (k).
109109
110110
A Float type representing the computed metric is returned.
111111
"""
112-
function sum_of_squares(containers, x, labels, centre, r, idx)
112+
function sum_of_squares(containers, x, labels, centre, weights, r, idx)
113113
s = zero(eltype(x))
114114

115-
@inbounds for j in r
116-
for i in axes(x, 1)
117-
s += (x[i, j] - centre[i, labels[j]])^2
118-
end
115+
@inbounds for i in r
116+
s += isnothing(weights) ? distance(x, centre, i, labels[i]) : weights[i] * distance(x, centre, i, labels[i])
119117
end
120118

121119
containers.sum_of_squares[idx] = s
@@ -148,14 +146,14 @@ alternatively one can use `rand` to choose random points for init.
148146
149147
A `KmeansResult` structure representing labels, centroids, and sum_squares is returned.
150148
"""
151-
function kmeans(alg, design_matrix, k;
149+
function kmeans(alg::AbstractKMeansAlg, design_matrix, k, weights = nothing;
152150
n_threads = Threads.nthreads(),
153151
k_init = "k-means++", max_iters = 300,
154152
tol = eltype(design_matrix)(1e-6), verbose = false, init = nothing)
155153
nrow, ncol = size(design_matrix)
156154
containers = create_containers(alg, design_matrix, k, nrow, ncol, n_threads)
157155

158-
return kmeans!(alg, containers, design_matrix, k, n_threads = n_threads,
156+
return kmeans!(alg, containers, design_matrix, k, weights, n_threads = n_threads,
159157
k_init = k_init, max_iters = max_iters, tol = tol,
160158
verbose = verbose, init = init)
161159
end

src/lloyd.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ found in `kmeans`.
1414
Argument `containers` represents algorithm-specific containers, such as labels, intermediate
1515
centroids and so on, which are used during calculations.
1616
"""
17-
function kmeans!(alg::Lloyd, containers, X, k;
17+
function kmeans!(alg::Lloyd, containers, X, k, weights;
1818
n_threads = Threads.nthreads(),
1919
k_init = "k-means++", max_iters = 300,
2020
tol = eltype(design_matrix)(1e-6), verbose = false, init = nothing)
@@ -28,7 +28,7 @@ function kmeans!(alg::Lloyd, containers, X, k;
2828

2929
# Update centroids & labels with closest members until convergence
3030
while niters <= max_iters
31-
@parallelize n_threads ncol chunk_update_centroids(alg, containers, centroids, X)
31+
@parallelize n_threads ncol chunk_update_centroids(alg, containers, centroids, X, weights)
3232
collect_containers(alg, containers, centroids, n_threads)
3333
J = sum(containers.J)
3434

@@ -47,7 +47,7 @@ function kmeans!(alg::Lloyd, containers, X, k;
4747
niters += 1
4848
end
4949

50-
@parallelize n_threads ncol sum_of_squares(containers, X, containers.labels, centroids)
50+
@parallelize n_threads ncol sum_of_squares(containers, X, containers.labels, centroids, weights)
5151
totalcost = sum(containers.sum_of_squares)
5252

5353
# Terminate algorithm with the assumption that K-means has converged
@@ -61,11 +61,11 @@ function kmeans!(alg::Lloyd, containers, X, k;
6161
return KmeansResult(centroids, containers.labels, T[], Int[], T[], totalcost, niters, converged)
6262
end
6363

64-
kmeans(design_matrix, k;
64+
kmeans(design_matrix, k, weights = nothing;
6565
n_threads = Threads.nthreads(),
6666
k_init = "k-means++", max_iters = 300, tol = 1e-6,
6767
verbose = false, init = nothing) =
68-
kmeans(Lloyd(), design_matrix, k; n_threads = n_threads, k_init = k_init, max_iters = max_iters, tol = tol,
68+
kmeans(Lloyd(), design_matrix, k, weights; n_threads = n_threads, k_init = k_init, max_iters = max_iters, tol = tol,
6969
verbose = verbose, init = init)
7070

7171
"""
@@ -99,7 +99,7 @@ function create_containers(::Lloyd, X, k, nrow, ncol, n_threads)
9999
labels = labels, J = J, sum_of_squares = sum_of_squares)
100100
end
101101

102-
function chunk_update_centroids(::Lloyd, containers, centroids, X, r, idx)
102+
function chunk_update_centroids(::Lloyd, containers, centroids, X, weights, r, idx)
103103
# unpack containers for easier manipulations
104104
centroids_new = containers.centroids_new[idx]
105105
centroids_cnt = containers.centroids_cnt[idx]
@@ -118,9 +118,9 @@ function chunk_update_centroids(::Lloyd, containers, centroids, X, r, idx)
118118
min_dist = dist < min_dist ? dist : min_dist
119119
end
120120
labels[i] = label
121-
centroids_cnt[label] += one(T)
121+
centroids_cnt[label] += isnothing(weights) ? one(T) : weights[i]
122122
for j in axes(X, 1)
123-
centroids_new[j, label] += X[j, i]
123+
centroids_new[j, label] += isnothing(weights) ? X[j, i] : weights[i] * X[j, i]
124124
end
125125
J += min_dist
126126
end

0 commit comments

Comments
 (0)