Skip to content

Commit d05b648

Browse files
committed
Added create_containers to minibatch
1 parent c52ff4f commit d05b648

File tree

2 files changed

+57
-11
lines changed

2 files changed

+57
-11
lines changed

src/lloyd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ Basic algorithm for k-means calculation.
66
struct Lloyd <: AbstractKMeansAlg end
77

88
"""
9-
Kmeans!(alg::AbstractKMeansAlg, containers, design_matrix, k; n_threads = nthreads(), k_init="k-means++", max_iters=300, tol=1e-6, verbose=true)
9+
kmeans!(alg::AbstractKMeansAlg, containers, design_matrix, k; n_threads = nthreads(), k_init="k-means++", max_iters=300, tol=1e-6, verbose=true)
1010
1111
Mutable version of `kmeans` function. Definition of arguments and results can be
1212
found in `kmeans`.

src/mini_batch.jl

Lines changed: 56 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22
MiniBatch(b::Int)
33
44
Sculley et al. 2007 Mini batch k-means algorithm implementation.
5+
6+
```julia
7+
X = rand(30, 100_000) # 100_000 random points in 30 dimensions
8+
9+
kmeans(MiniBatch(100), X, 3) # 3 clusters, MiniBatch algorithm with 100 batch samples at each iteration
10+
```
511
"""
612
struct MiniBatch <: AbstractKMeansAlg
713
b::Int # batch size
@@ -39,35 +45,34 @@ function kmeans!(alg::MiniBatch, X, k;
3945

4046
# b examples picked randomly from X (Step 5 in paper)
4147
batch_rand_idx = isnothing(weights) ? rand(rng, 1:ncol, alg.b) : wsample(rng, 1:ncol, weights, alg.b)
42-
batch_sample = X[:, batch_rand_idx]
4348

4449
# Cache/label the batch samples nearest to the centers (Step 6 & 7)
45-
@inbounds for i in axes(batch_sample, 2)
46-
min_dist = distance(metric, batch_sample, centroids, i, 1)
50+
@inbounds for i in batch_rand_idx
51+
min_dist = distance(metric, X, centroids, i, 1)
4752
label = 1
4853

4954
for j in 2:size(centroids, 2)
50-
dist = distance(metric, batch_sample, centroids, i, j)
55+
dist = distance(metric, X, centroids, i, j)
5156
label = dist < min_dist ? j : label
5257
min_dist = dist < min_dist ? dist : min_dist
5358
end
5459

55-
final_labels[batch_rand_idx[i]] = label
60+
final_labels[i] = label
5661
end
5762

5863
# TODO: Batch gradient step
59-
@inbounds for j in axes(batch_sample, 2) # iterate over examples (Step 9)
64+
@inbounds for j in batch_rand_idx # iterate over examples (Step 9)
6065

61-
# Get cached center/label for this x => labels[batch_rand_idx[j]] (Step 10)
62-
label = final_labels[batch_rand_idx[j]]
66+
# Get cached center/label for this x => (Step 10)
67+
label = final_labels[j]
6368
# Update per-center counts
6469
N[label] += isnothing(weights) ? 1 : weights[j] # verify (Step 11)
6570

6671
# Get per-center learning rate (Step 12)
6772
lr = 1 / N[label]
6873

69-
# Take gradient step (Step 13) # TODO: Replace with an allocation-less loop.
70-
centroids[:, label] .= (1 - lr) .* centroids[:, label] .+ (lr .* batch_sample[:, j])
74+
# Take gradient step (Step 13) # TODO: Replace with faster loop?
75+
@views centroids[:, label] .= (1 - lr) .* centroids[:, label] .+ (lr .* X[:, j])
7176
end
7277

7378
# TODO: Reassign all labels based on new centres generated from the latest sample
@@ -97,7 +102,17 @@ function kmeans!(alg::MiniBatch, X, k;
97102
end
98103
else
99104
counter = 0
105+
end
106+
107+
# TODO: Warn users if model doesn't converge at max iterations
108+
if (niters > max_iters) & (!converged)
109+
110+
println("Clustering model failed to converge. Labelling data with latest centroids.")
111+
final_labels = reassign_labels(X, metric, final_labels, centroids)
100112

113+
# TODO: Compute totalcost for unconverged model
114+
J = sum_of_squares(X, final_labels, centroids)
115+
break
101116
end
102117

103118
J_previous = J
@@ -135,3 +150,34 @@ function reassign_labels(DMatrix, metric, labels, centres)
135150
end
136151
return labels
137152
end
153+
154+
"""
155+
create_containers(::MiniBatch, k, nrow, ncol, n_threads)
156+
157+
Internal function for the creation of all necessary intermidiate structures.
158+
159+
- `centroids_new` - container which holds new positions of centroids
160+
- `centroids_cnt` - container which holds number of points for each centroid
161+
- `labels` - vector which holds labels of corresponding points
162+
"""
163+
function create_containers(::MiniBatch, X, k, nrow, ncol, n_threads)
164+
T = eltype(X)
165+
lng = n_threads + 1
166+
centroids_new = Vector{Matrix{T}}(undef, lng)
167+
centroids_cnt = Vector{Vector{T}}(undef, lng)
168+
169+
for i in 1:lng
170+
centroids_new[i] = Matrix{T}(undef, nrow, k)
171+
centroids_cnt[i] = Vector{Int}(undef, k)
172+
end
173+
174+
labels = Vector{Int}(undef, ncol)
175+
176+
J = Vector{T}(undef, n_threads)
177+
178+
# total_sum_calculation
179+
sum_of_squares = Vector{T}(undef, n_threads)
180+
181+
return (centroids_new = centroids_new, centroids_cnt = centroids_cnt,
182+
labels = labels, J = J, sum_of_squares = sum_of_squares)
183+
end

0 commit comments

Comments
 (0)