Skip to content

Commit 7cc6969

Browse files
authored
Merge pull request #81 from Arkoniak/rng_support
Improved support of MLJ
2 parents e0cd5f5 + 54cc9ba commit 7cc6969

21 files changed

+439
-309
lines changed

Project.toml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,26 @@
11
name = "ParallelKMeans"
22
uuid = "42b8e9d4-006b-409a-8472-7f34b3fb58af"
33
authors = ["Bernard Brenyah", "Andrey Oskin"]
4-
version = "0.1.6"
4+
version = "0.1.7"
55

66
[deps]
77
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
88
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
9+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
910
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1011

1112
[compat]
12-
StatsBase = "0.32, 0.33"
13-
julia = "1.3"
1413
Distances = "0.8.2"
1514
MLJModelInterface = "0.2.1"
15+
StatsBase = "0.32, 0.33"
16+
julia = "1.3"
1617

1718
[extras]
1819
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
1920
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2021
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
2122
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
23+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
2224

2325
[targets]
24-
test = ["Test", "Random", "Suppressor", "MLJBase"]
26+
test = ["Test", "Random", "Suppressor", "MLJBase", "StableRNGs"]

benchmark/bench01_distance.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@ Random.seed!(2020)
1010
X = rand(3, 100_000)
1111
centroids = rand(3, 2)
1212
d = Vector{Float64}(undef, 100_000)
13-
suite["100kx3"] = @benchmarkable ParallelKMeans.colwise!($d, $X, $centroids)
13+
suite["100kx3"] = @benchmarkable ParallelKMeans.chunk_colwise($d, $X, $centroids, 1, nothing, 1:100_000, 1)
1414

1515
X = rand(10, 100_000)
1616
centroids = rand(10, 2)
1717
d = Vector{Float64}(undef, 100_000)
18-
suite["100kx10"] = @benchmarkable ParallelKMeans.colwise!($d, $X, $centroids)
18+
suite["100kx10"] = @benchmarkable ParallelKMeans.chunk_colwise($d, $X, $centroids, 1, nothing, 1:100_000, 1)
1919

2020
end # module
2121

docs/src/index.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,12 @@ git checkout experimental
7474
- [X] Full Implementation of Triangle inequality based on [Elkan - 2003 Using the Triangle Inequality to Accelerate K-Means"](https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf).
7575
- [X] Implementation of [Yinyang K-Means: A Drop-In Replacement of the Classic K-Means with Consistent Speedup](https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/ding15.pdf).
7676
- [X] Implementation of [Coresets](http://proceedings.mlr.press/v51/lucic16-supp.pdf).
77-
- [ ] Implementation of [Geometric methods to accelerate k-means algorithm](http://cs.baylor.edu/~hamerly/papers/sdm2016_rysavy_hamerly.pdf).
7877
- [X] Support for weighted K-means.
78+
- [X] Support of MLJ Random generation hyperparameter.
7979
- [ ] Support for other distance metrics supported by [Distances.jl](https://github.com/JuliaStats/Distances.jl#supported-distances).
80-
- [ ] Support of MLJ Random generation hyperparameter.
80+
- [ ] Implementation of [Geometric methods to accelerate k-means algorithm](http://cs.baylor.edu/~hamerly/papers/sdm2016_rysavy_hamerly.pdf).
8181
- [ ] Native support for tabular data inputs outside of MLJModels' interface.
82-
- [ ] Refactoring and finalizaiton of API desgin.
82+
- [ ] Refactoring and finalization of API design.
8383
- [ ] GPU support.
8484
- [ ] Distributed calculations support.
8585
- [ ] Optimization of code base.
@@ -207,6 +207,7 @@ ________________________________________________________________________________
207207
- 0.1.4 Bug fixes.
208208
- 0.1.5 Added `Yinyang` algorithm.
209209
- 0.1.6 Added support for weighted k-means; Added `Coreset` algorithm; improved support for different types of the design matrix.
210+
- 0.1.7 Added `Yinyang` and `Coreset` support in MLJ interface; added `weights` support in MLJ; added RNG seed support in MLJ interface and through all algorithms.
210211

211212
## Contributing
212213

src/ParallelKMeans.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,19 @@
11
module ParallelKMeans
22

33
using StatsBase
4+
using Random
45
import MLJModelInterface
56
import Base.Threads: @spawn
67
import Distances
78

8-
const MMI = MLJModelInterface
9-
109
include("kmeans.jl")
1110
include("seeding.jl")
1211
include("lloyd.jl")
1312
include("hamerly.jl")
1413
include("elkan.jl")
1514
include("yinyang.jl")
16-
include("mlj_interface.jl")
1715
include("coreset.jl")
16+
include("mlj_interface.jl")
1817

1918
export kmeans
2019
export Lloyd, Hamerly, Elkan, Yinyang, 阴阳, Coreset

src/coreset.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,10 @@ Coreset(alg::AbstractKMeansAlg) = Coreset(100, alg)
3838
function kmeans!(alg::Coreset, containers, X, k, weights;
3939
n_threads = Threads.nthreads(),
4040
k_init = "k-means++", max_iters = 300,
41-
tol = eltype(design_matrix)(1e-6), verbose = false, init = nothing)
41+
tol = eltype(design_matrix)(1e-6), verbose = false,
42+
init = nothing, rng = Random.GLOBAL_RNG)
4243
nrow, ncol = size(X)
43-
centroids = isnothing(init) ? smart_init(X, k, n_threads, init=k_init).centroids : deepcopy(init)
44+
centroids = isnothing(init) ? smart_init(X, k, n_threads, weights, rng, init=k_init).centroids : deepcopy(init)
4445

4546
T = eltype(X)
4647
# Steps 2-4 of the paper's algorithm 3
@@ -54,14 +55,14 @@ function kmeans!(alg::Coreset, containers, X, k, weights;
5455
@parallelize n_threads ncol chunk_update_sensitivity(alg, containers)
5556

5657
# sample from containers.s
57-
coreset_ids = wsample(1:ncol, containers.s, alg.m)
58+
coreset_ids = wsample(rng, 1:ncol, containers.s, alg.m)
5859
coreset = X[:, coreset_ids]
5960
# create new weights as 1/s[i]
6061
coreset_weights = one(T) ./ @view containers.s[coreset_ids]
6162

6263
# run usual kmeans for new set with new weights.
63-
res = kmeans(alg.alg, coreset, k, coreset_weights, tol = tol, max_iters = max_iters,
64-
verbose = verbose, init = centroids, n_threads = n_threads)
64+
res = kmeans(alg.alg, coreset, k, weights = coreset_weights, tol = tol, max_iters = max_iters,
65+
verbose = verbose, init = centroids, n_threads = n_threads, rng = rng)
6566

6667
@parallelize n_threads ncol chunk_apply(alg, containers, res.centers, X, weights)
6768

src/elkan.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@ struct Elkan <: AbstractKMeansAlg end
2121
function kmeans!(alg::Elkan, containers, X, k, weights;
2222
n_threads = Threads.nthreads(),
2323
k_init = "k-means++", max_iters = 300,
24-
tol = eltype(X)(1e-6), verbose = false, init = nothing)
24+
tol = eltype(X)(1e-6), verbose = false,
25+
init = nothing, rng = Random.GLOBAL_RNG)
2526
nrow, ncol = size(X)
26-
centroids = init == nothing ? smart_init(X, k, n_threads, weights, init=k_init).centroids : deepcopy(init)
27+
centroids = init == nothing ? smart_init(X, k, n_threads, weights, rng, init=k_init).centroids : deepcopy(init)
2728

2829
update_containers(alg, containers, centroids, n_threads)
2930
@parallelize n_threads ncol chunk_initialize(alg, containers, centroids, X, weights)

src/hamerly.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@ struct Hamerly <: AbstractKMeansAlg end
2121
function kmeans!(alg::Hamerly, containers, X, k, weights;
2222
n_threads = Threads.nthreads(),
2323
k_init = "k-means++", max_iters = 300,
24-
tol = eltype(X)(1e-6), verbose = false, init = nothing)
24+
tol = eltype(X)(1e-6), verbose = false,
25+
init = nothing, rng = Random.GLOBAL_RNG)
2526
nrow, ncol = size(X)
26-
centroids = init == nothing ? smart_init(X, k, n_threads, weights, init=k_init).centroids : deepcopy(init)
27+
centroids = init == nothing ? smart_init(X, k, n_threads, weights, rng, init=k_init).centroids : deepcopy(init)
2728

2829
@parallelize n_threads ncol chunk_initialize(alg, containers, centroids, X, weights)
2930

src/kmeans.jl

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,16 @@ struct KmeansResult{C<:AbstractMatrix{<:AbstractFloat},D<:Real,WC<:Real} <: Clus
4040
converged::Bool # whether the procedure converged
4141
end
4242

43+
"""
44+
spliiter(n, k)
45+
46+
Internal utility function, splits 1:n sequence to k chunks of approximately same size.
47+
"""
48+
function splitter(n, k)
49+
xz = Int.(ceil.(range(0, n, length = k+1)))
50+
return [xz[i]+1:xz[i+1] for i in 1:k]
51+
end
52+
4353
"""
4454
@parallelize(n_threads, ncol, f)
4555
@@ -120,7 +130,8 @@ function sum_of_squares(containers, x, labels, centre, weights, r, idx)
120130
end
121131

122132
"""
123-
Kmeans([alg::AbstractKMeansAlg,] design_matrix, k; n_threads = nthreads(), k_init="k-means++", max_iters=300, tol=1e-6, verbose=true)
133+
kmeans([alg::AbstractKMeansAlg,] design_matrix, k; n_threads = nthreads(),
134+
k_init="k-means++", max_iters=300, tol=1e-6, verbose=true, rng = Random.GLOBAL_RNG)
124135
125136
This main function employs the K-means algorithm to cluster all examples
126137
in the training data (design_matrix) into k groups using either the
@@ -146,16 +157,18 @@ alternatively one can use `rand` to choose random points for init.
146157
147158
A `KmeansResult` structure representing labels, centroids, and sum_squares is returned.
148159
"""
149-
function kmeans(alg::AbstractKMeansAlg, design_matrix, k, weights = nothing;
160+
function kmeans(alg::AbstractKMeansAlg, design_matrix, k;
161+
weights = nothing,
150162
n_threads = Threads.nthreads(),
151163
k_init = "k-means++", max_iters = 300,
152-
tol = eltype(design_matrix)(1e-6), verbose = false, init = nothing)
164+
tol = eltype(design_matrix)(1e-6), verbose = false,
165+
init = nothing, rng = Random.GLOBAL_RNG)
153166
nrow, ncol = size(design_matrix)
154167
containers = create_containers(alg, design_matrix, k, nrow, ncol, n_threads)
155168

156169
return kmeans!(alg, containers, design_matrix, k, weights, n_threads = n_threads,
157170
k_init = k_init, max_iters = max_iters, tol = tol,
158-
verbose = verbose, init = init)
171+
verbose = verbose, init = init, rng = rng)
159172
end
160173

161174

src/lloyd.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@ centroids and so on, which are used during calculations.
1717
function kmeans!(alg::Lloyd, containers, X, k, weights;
1818
n_threads = Threads.nthreads(),
1919
k_init = "k-means++", max_iters = 300,
20-
tol = eltype(design_matrix)(1e-6), verbose = false, init = nothing)
20+
tol = eltype(design_matrix)(1e-6), verbose = false,
21+
init = nothing, rng = Random.GLOBAL_RNG)
2122
nrow, ncol = size(X)
22-
centroids = isnothing(init) ? smart_init(X, k, n_threads, weights, init=k_init).centroids : deepcopy(init)
23+
centroids = isnothing(init) ? smart_init(X, k, n_threads, weights, rng, init=k_init).centroids : deepcopy(init)
2324

2425
T = eltype(X)
2526
converged = false
@@ -61,12 +62,13 @@ function kmeans!(alg::Lloyd, containers, X, k, weights;
6162
return KmeansResult(centroids, containers.labels, T[], Int[], T[], totalcost, niters, converged)
6263
end
6364

64-
kmeans(design_matrix, k, weights = nothing;
65+
kmeans(design_matrix, k;
66+
weights = nothing,
6567
n_threads = Threads.nthreads(),
6668
k_init = "k-means++", max_iters = 300, tol = 1e-6,
67-
verbose = false, init = nothing) =
68-
kmeans(Lloyd(), design_matrix, k, weights; n_threads = n_threads, k_init = k_init, max_iters = max_iters, tol = tol,
69-
verbose = verbose, init = init)
69+
verbose = false, init = nothing, rng = Random.GLOBAL_RNG) =
70+
kmeans(Lloyd(), design_matrix, k; weights = weights, n_threads = n_threads, k_init = k_init, max_iters = max_iters, tol = tol,
71+
verbose = verbose, init = init, rng = rng)
7072

7173
"""
7274
create_containers(::Lloyd, k, nrow, ncol, n_threads)

src/mlj_interface.jl

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,51 @@
11
# Expose all instances of user specified structs and package artifcats.
2+
const MMI = MLJModelInterface
3+
24
const ParallelKMeans_Desc = "Parallel & lightning fast implementation of all available variants of the KMeans clustering algorithm
35
in native Julia. Compatible with Julia 1.3+"
46

57
# availalbe variants for reference
68
const MLJDICT = Dict(:Lloyd => Lloyd(),
79
:Hamerly => Hamerly(),
8-
:Elkan => Elkan())
10+
:Elkan => Elkan(),
11+
:Yinyang => Yinyang(),
12+
:Coreset => Coreset(),
13+
:阴阳 => Coreset())
914

1015
####
1116
#### MODEL DEFINITION
1217
####
1318

1419
mutable struct KMeans <: MMI.Unsupervised
15-
algo::Symbol
20+
algo::Union{Symbol, AbstractKMeansAlg}
1621
k_init::String
1722
k::Int
1823
tol::Float64
1924
max_iters::Int
2025
copy::Bool
2126
threads::Int
27+
rng::Union{AbstractRNG, Int}
28+
weights
2229
init
2330
end
2431

2532

26-
function KMeans(; algo=:Hamerly, k_init="k-means++",
27-
k=3, tol=1e-6, max_iters=300, copy=true,
28-
threads=Threads.nthreads(), init=nothing)
33+
function KMeans(; algo = :Hamerly, k_init = "k-means++",
34+
k = 3, tol = 1e-6, max_iters = 300, copy = true,
35+
threads = Threads.nthreads(), init = nothing,
36+
rng = Random.GLOBAL_RNG, weights = nothing)
2937

30-
model = KMeans(algo, k_init, k, tol, max_iters, copy, threads, init)
38+
model = KMeans(algo, k_init, k, tol, max_iters, copy, threads, rng, weights, init)
3139
message = MMI.clean!(model)
3240
isempty(message) || @warn message
3341
return model
3442
end
3543

3644

3745
function MMI.clean!(m::KMeans)
38-
warning = String[]
46+
warning = String[]
3947

40-
if !(m.algo keys(MLJDICT))
41-
push!(warning, "Unsupported KMeans variant. Defaulting to Hamerly algorithm.")
42-
m.algo = :Hamerly
43-
end
48+
m.algo = clean_algo(m.algo, warning)
4449

4550
if !(m.k_init ["k-means++", "random"])
4651
push!(warning, "Only \"k-means++\" or \"random\" seeding algorithms are supported. Defaulting to k-means++ seeding.")
@@ -89,15 +94,23 @@ function MMI.fit(m::KMeans, verbosity::Int, X)
8994
DMatrix = convert(Array{Float64, 2}, MMI.matrix(X, transpose=true))
9095
end
9196

92-
# lookup available algorithms
93-
algo = MLJDICT[m.algo] # select algo
97+
# setup rng
98+
rng = get_rng(m.rng)
99+
100+
if !isnothing(m.weights) && (size(DMatrix, 2) != length(m.weights))
101+
@warn "Size mismatch, number of points in X $(size(DMatrix, 2)) not equal weights length $(length(m.weights)). Weights parameter ignored."
102+
weights = nothing
103+
else
104+
105+
weights = m.weights
106+
end
94107

95108
# fit model and get results
96109
verbose = verbosity > 0 # Display fitting operations if verbosity > 0
97-
result = ParallelKMeans.kmeans(algo, DMatrix, m.k;
98-
n_threads = m.threads, k_init=m.k_init,
99-
max_iters=m.max_iters, tol=m.tol, init=m.init,
100-
verbose=verbose)
110+
result = ParallelKMeans.kmeans(m.algo, DMatrix, m.k;
111+
n_threads = m.threads, k_init = m.k_init,
112+
max_iters = m.max_iters, tol = m.tol, init = m.init,
113+
rng = rng, verbose = verbose, weights = weights)
101114

102115
cluster_labels = MMI.categorical(1:m.k)
103116
fitresult = (centers = result.centers, labels = cluster_labels, converged = result.converged)
@@ -192,3 +205,20 @@ MMI.metadata_model(KMeans,
192205
weights = false,
193206
descr = ParallelKMeans_Desc,
194207
path = "ParallelKMeans.KMeans")
208+
209+
####
210+
#### Auxiliary functions
211+
####
212+
213+
get_rng(rng::Int) = MersenneTwister(rng)
214+
get_rng(rng) = rng
215+
216+
clean_algo(algo::AbstractKMeansAlg, warning) = algo
217+
function clean_algo(algo::Symbol, warning)
218+
if !(algo keys(MLJDICT))
219+
push!(warning, "Unsupported KMeans variant. Defaulting to Hamerly algorithm.")
220+
return MLJDICT[:Hamerly]
221+
else
222+
return MLJDICT[algo]
223+
end
224+
end

0 commit comments

Comments
 (0)