"""
    MiniBatch(b::Int)

`b` represents the size of the batch which should be sampled at each iteration.

Sculley et al. 2007 Mini batch k-means algorithm implementation.

```julia
X = rand(30, 100_000) # 100_000 random points in 30 dimensions

kmeans(MiniBatch(100), X, 3) # 3 clusters, MiniBatch algorithm with 100 batch samples at each iteration
```
"""
# Mini-batch k-means algorithm marker type; dispatched on by `kmeans!` and
# `create_containers`.
struct MiniBatch <: AbstractKMeansAlg
    b::Int # batch size sampled from the dataset at each iteration
end

# Default batch size of 100 samples per iteration.
MiniBatch() = MiniBatch(100)
1219
13- function kmeans! (alg:: MiniBatch , X, k;
14- weights = nothing , metric = Euclidean (), n_threads = Threads. nthreads (),
20+ function kmeans! (alg:: MiniBatch , containers, X, k,
21+ weights = nothing , metric = Euclidean (); n_threads = Threads. nthreads (),
1522 k_init = " k-means++" , init = nothing , max_iters = 300 ,
1623 tol = eltype (X)(1e-6 ), max_no_improvement = 10 , verbose = false , rng = Random. GLOBAL_RNG)
1724
@@ -26,99 +33,100 @@ function kmeans!(alg::MiniBatch, X, k;
2633 N = zeros (T, k)
2734
2835 # Initialize nearest centers for both batch and whole dataset labels
29- final_labels = Vector {Int} (undef, ncol) # dataset labels
30-
3136 converged = false
3237 niters = 0
3338 counter = 0
3439 J_previous = zero (T)
3540 J = zero (T)
36-
37- # TODO : Main Steps. Batch update centroids until convergence
41+ totalcost = zero (T)
42+ batch_rand_idx = containers. batch_rand_idx
43+
44+ # Main Steps. Batch update centroids until convergence
3845 while niters <= max_iters # Step 4 in paper
3946
4047 # b examples picked randomly from X (Step 5 in paper)
41- batch_rand_idx = isnothing (weights) ? rand (rng, 1 : ncol, alg. b) : wsample (rng, 1 : ncol, weights, alg. b)
42- batch_sample = X[:, batch_rand_idx]
48+ batch_rand_idx = isnothing (weights) ? rand! (rng, batch_rand_idx, 1 : ncol) : wsample! (rng, 1 : ncol, weights, batch_rand_idx)
4349
4450 # Cache/label the batch samples nearest to the centers (Step 6 & 7)
45- @inbounds for i in axes (batch_sample, 2 )
46- min_dist = distance (metric, batch_sample , centroids, i, 1 )
51+ @inbounds for i in batch_rand_idx
52+ min_dist = distance (metric, X , centroids, i, 1 )
4753 label = 1
4854
4955 for j in 2 : size (centroids, 2 )
50- dist = distance (metric, batch_sample , centroids, i, j)
56+ dist = distance (metric, X , centroids, i, j)
5157 label = dist < min_dist ? j : label
5258 min_dist = dist < min_dist ? dist : min_dist
5359 end
5460
55- final_labels[batch_rand_idx[i]] = label
56- end
57-
58- # TODO : Batch gradient step
59- @inbounds for j in axes (batch_sample, 2 ) # iterate over examples (Step 9)
61+ containers. labels[i] = label
6062
61- # Get cached center/label for this x => labels[batch_rand_idx[j]] (Step 10)
62- label = final_labels[batch_rand_idx[j]]
63+ # #### Batch gradient step #####
64+ # iterate over examples (each column) ==> (Step 9)
65+ # Get cached center/label for each example label = labels[i] => (Step 10)
66+
6367 # Update per-center counts
64- N[label] += isnothing (weights) ? 1 : weights[j ] # verify (Step 11)
68+ N[label] += isnothing (weights) ? 1 : weights[i ] # (Step 11)
6569
6670 # Get per-center learning rate (Step 12)
6771 lr = 1 / N[label]
6872
69- # Take gradient step (Step 13) # TODO : Replace with an allocation-less loop.
70- centroids[:, label] .= (1 - lr) .* centroids[:, label] .+ (lr .* batch_sample [:, j ])
73+ # Take gradient step (Step 13) # TODO : Replace with faster loop?
74+ @views centroids[:, label] .= (1 - lr) .* centroids[:, label] .+ (lr .* X [:, i ])
7175 end
7276
73- # TODO : Reassign all labels based on new centres generated from the latest sample
74- final_labels = reassign_labels (X, metric, final_labels , centroids)
77+ # Reassign all labels based on new centres generated from the latest sample
78+ containers . labels . = reassign_labels (X, metric, containers . labels , centroids)
7579
76- # TODO : Calculate cost on whole dataset after reassignment and check for convergence
77- J = sum_of_squares (X, final_labels, centroids) # just a placeholder for now
80+ # Calculate cost on whole dataset after reassignment and check for convergence
81+ @parallelize 1 ncol sum_of_squares (containers, X, containers. labels, centroids, weights, metric)
82+ J = sum (containers. sum_of_squares)
7883
7984 if verbose
8085 # Show progress and terminate if J stopped decreasing.
8186 println (" Iteration $niters : Jclust = $J " )
8287 end
8388
84- # TODO : Check for early stopping convergence
89+ # Check for early stopping convergence
8590 if (niters > 1 ) & (abs (J - J_previous) < (tol * J))
8691 counter += 1
8792
8893 # Declare convergence if max_no_improvement criterion is met
8994 if counter >= max_no_improvement
9095 converged = true
91- # TODO : Compute label assignment for the complete dataset
92- final_labels = reassign_labels (X, metric, final_labels , centroids)
96+ # Compute label assignment for the complete dataset
97+ containers . labels . = reassign_labels (X, metric, containers . labels , centroids)
9398
94- # TODO : Compute totalcost for the complete dataset
95- J = sum_of_squares (X, final_labels, centroids) # just a placeholder for now
99+ # Compute totalcost for the complete dataset
100+ @parallelize 1 ncol sum_of_squares (containers, X, containers. labels, centroids, weights, metric)
101+ totalcost = sum (containers. sum_of_squares)
96102 break
97103 end
98104 else
99105 counter = 0
106+ end
107+
108+ # Warn users if model doesn't converge at max iterations
109+ if (niters > max_iters) & (! converged)
110+
111+ if verbose
112+ println (" Clustering model failed to converge. Labelling data with latest centroids." )
113+ end
114+ containers. labels = reassign_labels (X, metric, containers. labels, centroids)
100115
116+ # Compute totalcost for unconverged model
117+ @parallelize 1 ncol sum_of_squares (containers, X, containers. labels, centroids, weights, metric)
118+ totalcost = sum (containers. sum_of_squares)
119+ break
101120 end
102121
103122 J_previous = J
104123 niters += 1
105124 end
106125
107- return centroids, niters, converged, final_labels, J # TODO : push learned artifacts to KmeansResult
108- # return KmeansResult(centroids, containers.labels, T[], Int[], T[], totalcost, niters, converged)
126+ # Push learned artifacts to KmeansResult
127+ return KmeansResult (centroids, containers. labels, T[], Int[], T[], totalcost, niters, converged)
109128end
110129
"""
    sum_of_squares(x, labels, centre)

Compute the total within-cluster sum of squared (Euclidean) distances between
each column of `x` and its assigned centroid.

# Arguments
- `x`: data matrix with one observation per column.
- `labels`: vector mapping each column of `x` to a column index of `centre`.
- `centre`: matrix of centroids, one per column.

Returns the accumulated cost as `float(eltype(x))`.
"""
function sum_of_squares(x, labels, centre)
    # Accumulate in a float of the data's eltype instead of a hard-coded
    # Float64 `0.0`, so the function stays type-stable and precise for
    # Float32 (or other) inputs.
    s = zero(float(eltype(x)))

    for i in axes(x, 2)
        # Hoist the label lookup out of the inner (row) loop.
        c = labels[i]
        for j in axes(x, 1)
            s += (x[j, i] - centre[j, c])^2
        end
    end
    return s
end
122130
123131function reassign_labels (DMatrix, metric, labels, centres)
124132 @inbounds for i in axes (DMatrix, 2 )
@@ -135,3 +143,24 @@ function reassign_labels(DMatrix, metric, labels, centres)
135143 end
136144 return labels
137145end
146+
"""
    create_containers(alg::MiniBatch, X, k, nrow, ncol, n_threads)

Internal function for the creation of all necessary intermediate structures.

- `batch_rand_idx` - vector which holds the indices of the randomly sampled batch
- `labels` - vector which holds labels of corresponding points
- `sum_of_squares` - vector which holds the sum of squares values for each thread
"""
function create_containers(alg::MiniBatch, X, k, nrow, ncol, n_threads)
    # Initiate placeholders once so the main loop avoids per-iteration allocations.
    T = eltype(X)
    labels = Vector{Int}(undef, ncol)           # cluster assignment per column of X
    sum_of_squares = Vector{T}(undef, 1)        # total cost accumulator
    batch_rand_idx = Vector{Int}(undef, alg.b)  # reusable buffer for batch sampling

    return (batch_rand_idx = batch_rand_idx,
            labels = labels, sum_of_squares = sum_of_squares)
end
0 commit comments