11"""
22 MiniBatch(b::Int)
3+ `b` represents the size of the batch which should be sampled.
34
45 Sculley et al. 2007 Mini batch k-means algorithm implementation.
56
1617
1718MiniBatch () = MiniBatch (100 )
1819
19- function kmeans! (alg:: MiniBatch , X, k;
20- weights = nothing , metric = Euclidean (), n_threads = Threads. nthreads (),
20+ function kmeans! (alg:: MiniBatch , containers, X, k,
21+ weights = nothing , metric = Euclidean (); n_threads = Threads. nthreads (),
2122 k_init = " k-means++" , init = nothing , max_iters = 300 ,
2223 tol = eltype (X)(1e-6 ), max_no_improvement = 10 , verbose = false , rng = Random. GLOBAL_RNG)
2324
@@ -32,15 +33,14 @@ function kmeans!(alg::MiniBatch, X, k;
3233 N = zeros (T, k)
3334
3435 # Initialize nearest centers for both batch and whole dataset labels
35- final_labels = Vector {Int} (undef, ncol) # dataset labels
36-
3736 converged = false
3837 niters = 0
3938 counter = 0
4039 J_previous = zero (T)
4140 J = zero (T)
41+ totalcost = zero (T)
4242
43- # TODO : Main Steps. Batch update centroids until convergence
43+ # Main Steps. Batch update centroids until convergence
4444 while niters <= max_iters # Step 4 in paper
4545
4646 # b examples picked randomly from X (Step 5 in paper)
@@ -57,16 +57,17 @@ function kmeans!(alg::MiniBatch, X, k;
5757 min_dist = dist < min_dist ? dist : min_dist
5858 end
5959
60- final_labels [i] = label
60+ containers . labels [i] = label
6161 end
6262
63- # TODO : Batch gradient step
63+ # Batch gradient step
6464 @inbounds for j in batch_rand_idx # iterate over examples (Step 9)
6565
6666 # Get cached center/label for this x => (Step 10)
67- label = final_labels[j]
67+ label = containers. labels[j]
68+
6869 # Update per-center counts
69- N[label] += isnothing (weights) ? 1 : weights[j] # verify (Step 11)
70+ N[label] += isnothing (weights) ? 1 : weights[j] # (Step 11)
7071
7172 # Get per-center learning rate (Step 12)
7273 lr = 1 / N[label]
@@ -75,65 +76,57 @@ function kmeans!(alg::MiniBatch, X, k;
7576 @views centroids[:, label] .= (1 - lr) .* centroids[:, label] .+ (lr .* X[:, j])
7677 end
7778
78- # TODO : Reassign all labels based on new centres generated from the latest sample
79- final_labels = reassign_labels (X, metric, final_labels , centroids)
79+ # Reassign all labels based on new centres generated from the latest sample
80+ containers . labels . = reassign_labels (X, metric, containers . labels , centroids)
8081
81- # TODO : Calculate cost on whole dataset after reassignment and check for convergence
82- J = sum_of_squares (X, final_labels, centroids) # just a placeholder for now
82+ # Calculate cost on whole dataset after reassignment and check for convergence
83+ @parallelize 1 ncol sum_of_squares (containers, X, containers. labels, centroids, weights, metric)
84+ J = sum (containers. sum_of_squares)
8385
8486 if verbose
8587 # Show progress and terminate if J stopped decreasing.
8688 println (" Iteration $niters : Jclust = $J " )
8789 end
8890
89- # TODO : Check for early stopping convergence
91+ # Check for early stopping convergence
9092 if (niters > 1 ) & (abs (J - J_previous) < (tol * J))
9193 counter += 1
9294
9395 # Declare convergence if max_no_improvement criterion is met
9496 if counter >= max_no_improvement
9597 converged = true
96- # TODO : Compute label assignment for the complete dataset
97- final_labels = reassign_labels (X, metric, final_labels , centroids)
98+ # Compute label assignment for the complete dataset
99+ containers . labels . = reassign_labels (X, metric, containers . labels , centroids)
98100
99- # TODO : Compute totalcost for the complete dataset
100- J = sum_of_squares (X, final_labels, centroids) # just a placeholder for now
101+ # Compute totalcost for the complete dataset
102+ @parallelize 1 ncol sum_of_squares (containers, X, containers. labels, centroids, weights, metric)
103+ totalcost = sum (containers. sum_of_squares)
101104 break
102105 end
103106 else
104107 counter = 0
105108 end
106109
107- # TODO : Warn users if model doesn't converge at max iterations
110+ # Warn users if model doesn't converge at max iterations
108111 if (niters > max_iters) & (! converged)
109112
110113 println (" Clustering model failed to converge. Labelling data with latest centroids." )
111- final_labels = reassign_labels (X, metric, final_labels , centroids)
114+ containers . labels = reassign_labels (X, metric, containers . labels , centroids)
112115
113- # TODO : Compute totalcost for unconverged model
114- J = sum_of_squares (X, final_labels, centroids)
116+ # Compute totalcost for unconverged model
117+ @parallelize 1 ncol sum_of_squares (containers, X, containers. labels, centroids, weights, metric)
118+ totalcost = sum (containers. sum_of_squares)
115119 break
116120 end
117121
118122 J_previous = J
119123 niters += 1
120124 end
121125
122- return centroids, niters, converged, final_labels, J # TODO : push learned artifacts to KmeansResult
123- # return KmeansResult(centroids, containers.labels, T[], Int[], T[], totalcost, niters, converged)
126+ # Push learned artifacts to KmeansResult
127+ return KmeansResult (centroids, containers. labels, T[], Int[], T[], totalcost, niters, converged)
124128end
125129
126- # TODO : Only being used to test generic implementation. Get rid off after!
127- function sum_of_squares (x, labels, centre)
128- s = 0.0
129-
130- for i in axes (x, 2 )
131- for j in axes (x, 1 )
132- s += (x[j, i] - centre[j, labels[i]])^ 2
133- end
134- end
135- return s
136- end
137130
138131function reassign_labels (DMatrix, metric, labels, centres)
139132 @inbounds for i in axes (DMatrix, 2 )
@@ -159,25 +152,16 @@ Internal function for the creation of all necessary intermidiate structures.
159152- `centroids_new` - container which holds new positions of centroids
160153- `centroids_cnt` - container which holds number of points for each centroid
161154- `labels` - vector which holds labels of corresponding points
155+ - `sum_of_squares` - vector which holds the sum of squares values for each thread
162156"""
163157function create_containers (:: MiniBatch , X, k, nrow, ncol, n_threads)
164- T = eltype (X)
165- lng = n_threads + 1
166- centroids_new = Vector {Matrix{T}} (undef, lng)
167- centroids_cnt = Vector {Vector{T}} (undef, lng)
168-
169- for i in 1 : lng
170- centroids_new[i] = Matrix {T} (undef, nrow, k)
171- centroids_cnt[i] = Vector {Int} (undef, k)
172- end
173-
174- labels = Vector {Int} (undef, ncol)
175-
176- J = Vector {T} (undef, n_threads)
177-
178- # total_sum_calculation
179- sum_of_squares = Vector {T} (undef, n_threads)
158+ # Initiate placeholders to avoid allocations
159+ T = eltype (X)
160+ centroids_new = Matrix {T} (undef, nrow, k) # main centroids
161+ centroids_cnt = Vector {T} (undef, k) # centroids counter
162+ labels = Vector {Int} (undef, ncol) # labels vector
163+ sum_of_squares = Vector {T} (undef, 1 ) # total_sum_calculation
180164
181165 return (centroids_new = centroids_new, centroids_cnt = centroids_cnt,
182- labels = labels, J = J, sum_of_squares = sum_of_squares)
166+ labels = labels, sum_of_squares = sum_of_squares)
183167end
0 commit comments