|
1 | 1 | # Expose all instances of user specified structs and package artifacts. |
| 2 | +const MMI = MLJModelInterface |
| 3 | + |
2 | 4 | const ParallelKMeans_Desc = "Parallel & lightning fast implementation of all available variants of the KMeans clustering algorithm |
3 | 5 | in native Julia. Compatible with Julia 1.3+" |
4 | 6 |
|
5 | 7 | # available variants for reference |
# Lookup table from a user-facing algorithm symbol to a default-constructed
# algorithm instance. `clean_algo` uses it to resolve a `Symbol` given as the
# model's `algo` hyperparameter.
const MLJDICT = Dict(:Lloyd => Lloyd(),
                     :Hamerly => Hamerly(),
                     :Elkan => Elkan(),
                     :Yinyang => Yinyang(),
                     :Coreset => Coreset(),
                     # 阴阳 reads "yin-yang": it is an alias for the Yinyang
                     # variant, so it must not resolve to Coreset.
                     :阴阳 => Yinyang())
9 | 14 |
|
10 | 15 | #### |
11 | 16 | #### MODEL DEFINITION |
12 | 17 | #### |
13 | 18 |
|
# MLJ model type carrying the hyperparameters of a ParallelKMeans clustering run.
# Field order matters: the keyword constructor fills the fields positionally.
mutable struct KMeans <: MMI.Unsupervised
    # KMeans variant: either a key of `MLJDICT` or an algorithm instance.
    algo::Union{Symbol, AbstractKMeansAlg}
    # Seeding scheme; `clean!` accepts "k-means++" or "random".
    k_init::String
    # Number of clusters.
    k::Int
    # Forwarded to `ParallelKMeans.kmeans` as `tol`.
    tol::Float64
    # Forwarded to `ParallelKMeans.kmeans` as `max_iters`.
    max_iters::Int
    # NOTE(review): presumably controls copying of the input data; its use is
    # not visible in this part of the file -- confirm.
    copy::Bool
    # Forwarded to `ParallelKMeans.kmeans` as `n_threads`.
    threads::Int
    # Random number generator, or an integer seed turned into one by `get_rng`.
    rng::Union{AbstractRNG, Int}
    # Optional per-point weights (`nothing` to disable); `fit` ignores them
    # when their length differs from the number of points.
    weights::Any
    # Forwarded to `ParallelKMeans.kmeans` as `init`.
    init::Any
end
24 | 31 |
|
25 | 32 |
|
# Keyword constructor for the `KMeans` model.
#
# Builds the model from the given hyperparameters, runs `MMI.clean!` on it and
# emits a warning when `clean!` reports a non-empty message. The (possibly
# corrected) model is returned.
function KMeans(;
    algo = :Hamerly,
    k_init = "k-means++",
    k = 3,
    tol = 1e-6,
    max_iters = 300,
    copy = true,
    threads = Threads.nthreads(),
    init = nothing,
    rng = Random.GLOBAL_RNG,
    weights = nothing,
)
    # Positional order follows the struct's field order (rng, weights, init last).
    mdl = KMeans(algo, k_init, k, tol, max_iters, copy, threads, rng, weights, init)

    msg = MMI.clean!(mdl)
    if !isempty(msg)
        @warn msg
    end

    return mdl
end
35 | 43 |
|
36 | 44 |
|
37 | 45 | function MMI.clean!(m::KMeans) |
38 | | - warning = String[] |
| 46 | + warning = String[] |
39 | 47 |
|
40 | | - if !(m.algo ∈ keys(MLJDICT)) |
41 | | - push!(warning, "Unsupported KMeans variant. Defaulting to Hamerly algorithm.") |
42 | | - m.algo = :Hamerly |
43 | | - end |
| 48 | + m.algo = clean_algo(m.algo, warning) |
44 | 49 |
|
45 | 50 | if !(m.k_init ∈ ["k-means++", "random"]) |
46 | 51 | push!(warning, "Only \"k-means++\" or \"random\" seeding algorithms are supported. Defaulting to k-means++ seeding.") |
@@ -89,15 +94,23 @@ function MMI.fit(m::KMeans, verbosity::Int, X) |
89 | 94 | DMatrix = convert(Array{Float64, 2}, MMI.matrix(X, transpose=true)) |
90 | 95 | end |
91 | 96 |
|
92 | | - # lookup available algorithms |
93 | | - algo = MLJDICT[m.algo] # select algo |
| 97 | + # setup rng |
| 98 | + rng = get_rng(m.rng) |
| 99 | + |
| 100 | + if !isnothing(m.weights) && (size(DMatrix, 2) != length(m.weights)) |
| 101 | + @warn "Size mismatch, number of points in X $(size(DMatrix, 2)) not equal weights length $(length(m.weights)). Weights parameter ignored." |
| 102 | + weights = nothing |
| 103 | + else |
| 104 | + |
| 105 | + weights = m.weights |
| 106 | + end |
94 | 107 |
|
95 | 108 | # fit model and get results |
96 | 109 | verbose = verbosity > 0 # Display fitting operations if verbosity > 0 |
97 | | - result = ParallelKMeans.kmeans(algo, DMatrix, m.k; |
98 | | - n_threads = m.threads, k_init=m.k_init, |
99 | | - max_iters=m.max_iters, tol=m.tol, init=m.init, |
100 | | - verbose=verbose) |
| 110 | + result = ParallelKMeans.kmeans(m.algo, DMatrix, m.k; |
| 111 | + n_threads = m.threads, k_init = m.k_init, |
| 112 | + max_iters = m.max_iters, tol = m.tol, init = m.init, |
| 113 | + rng = rng, verbose = verbose, weights = weights) |
101 | 114 |
|
102 | 115 | cluster_labels = MMI.categorical(1:m.k) |
103 | 116 | fitresult = (centers = result.centers, labels = cluster_labels, converged = result.converged) |
@@ -192,3 +205,20 @@ MMI.metadata_model(KMeans, |
192 | 205 | weights = false, |
193 | 206 | descr = ParallelKMeans_Desc, |
194 | 207 | path = "ParallelKMeans.KMeans") |
| 208 | + |
| 209 | +#### |
| 210 | +#### Auxiliary functions |
| 211 | +#### |
| 212 | + |
# Normalise the user-supplied `rng` setting.
#
# An integer of any width is treated as a seed and turned into a fresh
# `MersenneTwister`; any other value is returned unchanged.
get_rng(rng::Integer) = MersenneTwister(rng)
get_rng(rng) = rng
| 215 | + |
# Resolve the `algo` hyperparameter to an algorithm instance.
#
# An algorithm instance is returned untouched. A `Symbol` is looked up in
# `MLJDICT`; when the key is unknown, a message is appended to `warning` and
# the Hamerly entry is returned instead.
clean_algo(algo::AbstractKMeansAlg, warning) = algo
function clean_algo(algo::Symbol, warning)
    # Known variant: hand back the registered instance straight away.
    haskey(MLJDICT, algo) && return MLJDICT[algo]

    push!(warning, "Unsupported KMeans variant. Defaulting to Hamerly algorithm.")
    return MLJDICT[:Hamerly]
end
0 commit comments