Skip to content

Commit ce3eeeb

Browse files
committed
Optimised implementation based on initial feedback. Allocations down & speed gains
1 parent 67add44 commit ce3eeeb

File tree

2 files changed

+15
-16
lines changed

2 files changed

+15
-16
lines changed

docs/src/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,8 @@ r.converged # whether the procedure converged
127127
- [Elkan()](https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf) - Recommended for high dimensional data.
128128
- [Yinyang()](https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/ding15.pdf) - Recommended for large dimensions and/or a large number of clusters.
129129
- [Coreset()](http://proceedings.mlr.press/v51/lucic16-supp.pdf) - Recommended for very fast clustering of very large datasets, when extreme accuracy is not important.
130+
- [MiniBatch()](https://www.eecs.tufts.edu/~dsculley/papers/fastkmeans.pdf) - Recommended for extremely large datasets.
130131
- [Geometric()](http://cs.baylor.edu/~hamerly/papers/sdm2016_rysavy_hamerly.pdf) - (Coming soon)
131-
- [MiniBatch()](https://www.eecs.tufts.edu/~dsculley/papers/fastkmeans.pdf) - (Coming soon)
132132

133133
### Practical Usage Examples
134134

src/mini_batch.jl

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,13 @@ function kmeans!(alg::MiniBatch, containers, X, k,
3939
J_previous = zero(T)
4040
J = zero(T)
4141
totalcost = zero(T)
42-
42+
batch_rand_idx = containers.batch_rand_idx
43+
4344
# Main Steps. Batch update centroids until convergence
4445
while niters <= max_iters # Step 4 in paper
4546

4647
# b examples picked randomly from X (Step 5 in paper)
47-
batch_rand_idx = isnothing(weights) ? rand(rng, 1:ncol, alg.b) : wsample(rng, 1:ncol, weights, alg.b)
48+
batch_rand_idx = isnothing(weights) ? rand!(rng, batch_rand_idx, 1:ncol) : wsample!(rng, 1:ncol, weights, batch_rand_idx)
4849

4950
# Cache/label the batch samples nearest to the centers (Step 6 & 7)
5051
@inbounds for i in batch_rand_idx
@@ -58,22 +59,19 @@ function kmeans!(alg::MiniBatch, containers, X, k,
5859
end
5960

6061
containers.labels[i] = label
61-
end
6262

63-
# Batch gradient step
64-
@inbounds for j in batch_rand_idx # iterate over examples (Step 9)
65-
66-
# Get cached center/label for this x => (Step 10)
67-
label = containers.labels[j]
63+
##### Batch gradient step #####
64+
# iterate over examples (each column) ==> (Step 9)
65+
# Get cached center/label for each example, i.e. label = labels[i] => (Step 10)
6866

6967
# Update per-center counts
70-
N[label] += isnothing(weights) ? 1 : weights[j] # (Step 11)
68+
N[label] += isnothing(weights) ? 1 : weights[i] # (Step 11)
7169

7270
# Get per-center learning rate (Step 12)
7371
lr = 1 / N[label]
7472

7573
# Take gradient step (Step 13) # TODO: Replace with faster loop?
76-
@views centroids[:, label] .= (1 - lr) .* centroids[:, label] .+ (lr .* X[:, j])
74+
@views centroids[:, label] .= (1 - lr) .* centroids[:, label] .+ (lr .* X[:, i])
7775
end
7876

7977
# Reassign all labels based on new centres generated from the latest sample
@@ -110,7 +108,9 @@ function kmeans!(alg::MiniBatch, containers, X, k,
110108
# Warn users if model doesn't converge at max iterations
111109
if (niters > max_iters) & (!converged)
112110

113-
println("Clustering model failed to converge. Labelling data with latest centroids.")
111+
if verbose
112+
println("Clustering model failed to converge. Labelling data with latest centroids.")
113+
end
114114
containers.labels = reassign_labels(X, metric, containers.labels, centroids)
115115

116116
# Compute totalcost for unconverged model
@@ -154,14 +154,13 @@ Internal function for the creation of all necessary intermediate structures.
154154
- `labels` - vector which holds labels of corresponding points
155155
- `sum_of_squares` - vector which holds the sum of squares values for each thread
156156
"""
157-
function create_containers(::MiniBatch, X, k, nrow, ncol, n_threads)
157+
function create_containers(alg::MiniBatch, X, k, nrow, ncol, n_threads)
158158
# Initiate placeholders to avoid allocations
159159
T = eltype(X)
160-
centroids_new = Matrix{T}(undef, nrow, k) # main centroids
161-
centroids_cnt = Vector{T}(undef, k) # centroids counter
162160
labels = Vector{Int}(undef, ncol) # labels vector
163161
sum_of_squares = Vector{T}(undef, 1) # total_sum_calculation
162+
batch_rand_idx = Vector{Int}(undef, alg.b)
164163

165-
return (centroids_new = centroids_new, centroids_cnt = centroids_cnt,
164+
return (batch_rand_idx = batch_rand_idx,
166165
labels = labels, sum_of_squares = sum_of_squares)
167166
end

0 commit comments

Comments
 (0)