Skip to content

Commit 17214f1

Browse files
committed
Mini batch now fully compliant with API. Optimizations left
1 parent 04d63d5 commit 17214f1

File tree

3 files changed

+58
-68
lines changed

3 files changed

+58
-68
lines changed

src/kmeans.jl

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -170,20 +170,30 @@ alternatively one can use `rand` to choose random points for init.
170170
171171
A `KmeansResult` structure representing labels, centroids, and sum_squares is returned.
172172
"""
173-
function kmeans(alg::AbstractKMeansAlg, design_matrix, k; weights = nothing,
173+
function kmeans(alg::AbstractKMeansAlg, design_matrix, k;
174+
weights = nothing,
174175
n_threads = Threads.nthreads(),
175-
k_init = "k-means++", max_iters = 300,
176-
tol = eltype(design_matrix)(1e-6), verbose = false,
177-
init = nothing, rng = Random.GLOBAL_RNG, metric = Euclidean())
176+
k_init = "k-means++",
177+
max_iters = 300,
178+
tol = eltype(design_matrix)(1e-6),
179+
verbose = false,
180+
init = nothing,
181+
rng = Random.GLOBAL_RNG,
182+
metric = Euclidean())
178183

179184
nrow, ncol = size(design_matrix)
180185

181186
# Create containers based on the dimensions and specifications
182187
containers = create_containers(alg, design_matrix, k, nrow, ncol, n_threads)
183188

184189
return kmeans!(alg, containers, design_matrix, k, weights, metric;
185-
n_threads = n_threads, k_init = k_init, max_iters = max_iters,
186-
tol = tol, verbose = verbose, init = init, rng = rng)
190+
n_threads = n_threads,
191+
k_init = k_init,
192+
max_iters = max_iters,
193+
tol = tol,
194+
verbose = verbose,
195+
init = init,
196+
rng = rng)
187197

188198
end
189199

src/mini_batch.jl

Lines changed: 36 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""
22
MiniBatch(b::Int)
3+
`b` represents the size of the batch which should be sampled.
34
45
Sculley 2010 ("Web-Scale K-Means Clustering") mini-batch k-means algorithm implementation.
56
@@ -16,8 +17,8 @@ end
1617

1718
MiniBatch() = MiniBatch(100)
1819

19-
function kmeans!(alg::MiniBatch, X, k;
20-
weights = nothing, metric = Euclidean(), n_threads = Threads.nthreads(),
20+
function kmeans!(alg::MiniBatch, containers, X, k,
21+
weights = nothing, metric = Euclidean(); n_threads = Threads.nthreads(),
2122
k_init = "k-means++", init = nothing, max_iters = 300,
2223
tol = eltype(X)(1e-6), max_no_improvement = 10, verbose = false, rng = Random.GLOBAL_RNG)
2324

@@ -32,15 +33,14 @@ function kmeans!(alg::MiniBatch, X, k;
3233
N = zeros(T, k)
3334

3435
# Initialize nearest centers for both batch and whole dataset labels
35-
final_labels = Vector{Int}(undef, ncol) # dataset labels
36-
3736
converged = false
3837
niters = 0
3938
counter = 0
4039
J_previous = zero(T)
4140
J = zero(T)
41+
totalcost = zero(T)
4242

43-
# TODO: Main Steps. Batch update centroids until convergence
43+
# Main Steps. Batch update centroids until convergence
4444
while niters <= max_iters # Step 4 in paper
4545

4646
# b examples picked randomly from X (Step 5 in paper)
@@ -57,16 +57,17 @@ function kmeans!(alg::MiniBatch, X, k;
5757
min_dist = dist < min_dist ? dist : min_dist
5858
end
5959

60-
final_labels[i] = label
60+
containers.labels[i] = label
6161
end
6262

63-
# TODO: Batch gradient step
63+
# Batch gradient step
6464
@inbounds for j in batch_rand_idx # iterate over examples (Step 9)
6565

6666
# Get cached center/label for this x => (Step 10)
67-
label = final_labels[j]
67+
label = containers.labels[j]
68+
6869
# Update per-center counts
69-
N[label] += isnothing(weights) ? 1 : weights[j] # verify (Step 11)
70+
N[label] += isnothing(weights) ? 1 : weights[j] # (Step 11)
7071

7172
# Get per-center learning rate (Step 12)
7273
lr = 1 / N[label]
@@ -75,65 +76,57 @@ function kmeans!(alg::MiniBatch, X, k;
7576
@views centroids[:, label] .= (1 - lr) .* centroids[:, label] .+ (lr .* X[:, j])
7677
end
7778

78-
# TODO: Reassign all labels based on new centres generated from the latest sample
79-
final_labels = reassign_labels(X, metric, final_labels, centroids)
79+
# Reassign all labels based on new centres generated from the latest sample
80+
containers.labels .= reassign_labels(X, metric, containers.labels, centroids)
8081

81-
# TODO: Calculate cost on whole dataset after reassignment and check for convergence
82-
J = sum_of_squares(X, final_labels, centroids) # just a placeholder for now
82+
# Calculate cost on whole dataset after reassignment and check for convergence
83+
@parallelize 1 ncol sum_of_squares(containers, X, containers.labels, centroids, weights, metric)
84+
J = sum(containers.sum_of_squares)
8385

8486
if verbose
8587
# Show progress and terminate if J stopped decreasing.
8688
println("Iteration $niters: Jclust = $J")
8789
end
8890

89-
# TODO: Check for early stopping convergence
91+
# Check for early stopping convergence
9092
if (niters > 1) & (abs(J - J_previous) < (tol * J))
9193
counter += 1
9294

9395
# Declare convergence if max_no_improvement criterion is met
9496
if counter >= max_no_improvement
9597
converged = true
96-
# TODO: Compute label assignment for the complete dataset
97-
final_labels = reassign_labels(X, metric, final_labels, centroids)
98+
# Compute label assignment for the complete dataset
99+
containers.labels .= reassign_labels(X, metric, containers.labels, centroids)
98100

99-
# TODO: Compute totalcost for the complete dataset
100-
J = sum_of_squares(X, final_labels, centroids) # just a placeholder for now
101+
# Compute totalcost for the complete dataset
102+
@parallelize 1 ncol sum_of_squares(containers, X, containers.labels, centroids, weights, metric)
103+
totalcost = sum(containers.sum_of_squares)
101104
break
102105
end
103106
else
104107
counter = 0
105108
end
106109

107-
# TODO: Warn users if model doesn't converge at max iterations
110+
# Warn users if model doesn't converge at max iterations
108111
if (niters > max_iters) & (!converged)
109112

110113
println("Clustering model failed to converge. Labelling data with latest centroids.")
111-
final_labels = reassign_labels(X, metric, final_labels, centroids)
114+
# `containers` is an immutable NamedTuple, so its `labels` field cannot be rebound;
# write into the existing vector in place, as the other reassignment sites do.
containers.labels .= reassign_labels(X, metric, containers.labels, centroids)
112115

113-
# TODO: Compute totalcost for unconverged model
114-
J = sum_of_squares(X, final_labels, centroids)
116+
# Compute totalcost for unconverged model
117+
@parallelize 1 ncol sum_of_squares(containers, X, containers.labels, centroids, weights, metric)
118+
totalcost = sum(containers.sum_of_squares)
115119
break
116120
end
117121

118122
J_previous = J
119123
niters += 1
120124
end
121125

122-
return centroids, niters, converged, final_labels, J # TODO: push learned artifacts to KmeansResult
123-
#return KmeansResult(centroids, containers.labels, T[], Int[], T[], totalcost, niters, converged)
126+
# Push learned artifacts to KmeansResult
127+
return KmeansResult(centroids, containers.labels, T[], Int[], T[], totalcost, niters, converged)
124128
end
125129

126-
# TODO: Only being used to test generic implementation. Get rid off after!
127-
function sum_of_squares(x, labels, centre)
128-
s = 0.0
129-
130-
for i in axes(x, 2)
131-
for j in axes(x, 1)
132-
s += (x[j, i] - centre[j, labels[i]])^2
133-
end
134-
end
135-
return s
136-
end
137130

138131
function reassign_labels(DMatrix, metric, labels, centres)
139132
@inbounds for i in axes(DMatrix, 2)
@@ -159,25 +152,16 @@ Internal function for the creation of all necessary intermediate structures.
159152
- `centroids_new` - container which holds new positions of centroids
160153
- `centroids_cnt` - container which holds number of points for each centroid
161154
- `labels` - vector which holds labels of corresponding points
155+
- `sum_of_squares` - vector which holds the sum of squares values for each thread
162156
"""
163157
function create_containers(::MiniBatch, X, k, nrow, ncol, n_threads)
164-
T = eltype(X)
165-
lng = n_threads + 1
166-
centroids_new = Vector{Matrix{T}}(undef, lng)
167-
centroids_cnt = Vector{Vector{T}}(undef, lng)
168-
169-
for i in 1:lng
170-
centroids_new[i] = Matrix{T}(undef, nrow, k)
171-
centroids_cnt[i] = Vector{Int}(undef, k)
172-
end
173-
174-
labels = Vector{Int}(undef, ncol)
175-
176-
J = Vector{T}(undef, n_threads)
177-
178-
# total_sum_calculation
179-
sum_of_squares = Vector{T}(undef, n_threads)
158+
# Initiate placeholders to avoid allocations
159+
T = eltype(X)
160+
centroids_new = Matrix{T}(undef, nrow, k) # main centroids
161+
centroids_cnt = Vector{T}(undef, k) # centroids counter
162+
labels = Vector{Int}(undef, ncol) # labels vector
163+
sum_of_squares = Vector{T}(undef, 1) # total_sum_calculation
180164

181165
return (centroids_new = centroids_new, centroids_cnt = centroids_cnt,
182-
labels = labels, J = J, sum_of_squares = sum_of_squares)
166+
labels = labels, sum_of_squares = sum_of_squares)
183167
end

test/test90_minibatch.jl

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ end
1616
rng = StableRNG(2020)
1717
X = rand(rng, 3, 100)
1818

19-
baseline = [kmeans(Lloyd(), X, 2).totalcost for i in 1:1_000] |> mean |> round
20-
# TODO: Switch to kmeans after full implementation
21-
res = [ParallelKMeans.kmeans!(MiniBatch(50), X, 2)[end] for i in 1:1_000] |> mean |> round
19+
baseline = [kmeans(Lloyd(), X, 2; max_iters=100_000).totalcost for i in 1:200] |> mean |> round
20+
21+
res = [kmeans(MiniBatch(10), X, 2; max_iters=100_000).totalcost for i in 1:200] |> mean |> round
2222

2323
@test baseline == res
2424
end
@@ -28,13 +28,9 @@ end
2828
rng = StableRNG(2020)
2929
X = rand(rng, 3, 100)
3030

31-
baseline = [kmeans(Lloyd(), X, 2;
32-
tol=1e-6, metric=Cityblock(),
33-
max_iters=500).totalcost for i in 1:1000] |> mean |> floor
34-
# TODO: Switch to kmeans after full implementation
35-
res = [ParallelKMeans.kmeans!(MiniBatch(), X, 2;
36-
metric=Cityblock(), tol=1e-6,
37-
max_iters=500)[end] for i in 1:1000] |> mean |> floor
31+
baseline = [kmeans(Lloyd(), X, 2; metric=Cityblock(), max_iters=100_000).totalcost for i in 1:200] |> mean |> round
32+
33+
res = [kmeans(MiniBatch(10), X, 2; metric=Cityblock(), max_iters=100_000).totalcost for i in 1:200] |> mean |> round
3834

3935
@test baseline == res
4036
end

0 commit comments

Comments
 (0)