1- # TODO 1: a using MLJModelInterface or import MLJModelInterface statement
21# Expose all instances of user specified structs and package artifcats.
3- const ParallelKMeans_Desc = " Parallel & lightning fast implementation of all variants of the KMeans clustering algorithm in native Julia."
2+ const ParallelKMeans_Desc = " Parallel & lightning fast implementation of all available variants of the KMeans clustering algorithm
3+ in native Julia. Compatible with Julia 1.3+"
44
55# availalbe variants for reference
66const MLJDICT = Dict (:Lloyd => Lloyd (),
@@ -10,7 +10,6 @@ const MLJDICT = Dict(:Lloyd => Lloyd(),
1010# ###
1111# ### MODEL DEFINITION
1212# ###
13- # TODO 2: MLJ-compatible model types and constructors
1413
1514mutable struct KMeans <: MLJModelInterface.Unsupervised
1615 algo:: Symbol
@@ -40,7 +39,7 @@ function MLJModelInterface.clean!(m::KMeans)
4039 warning = " "
4140
4241 if ! (m. algo ∈ keys (MLJDICT))
43- warning *= " Unsuppored algorithm supplied. Defauting to KMeans++ seeding algorithm."
42+ warning *= " Unsupported KMeans variant, Defauting to KMeans++ seeding algorithm."
4443 m. algo = :Lloyd
4544
4645 elseif m. k_init != " k-means++"
@@ -71,24 +70,22 @@ function MLJModelInterface.clean!(m::KMeans)
7170end
7271
7372
74- # TODO 3: implementation of fit, predict, and fitted_params of the model
7573# ###
7674# ### FIT FUNCTION
7775# ###
7876"""
79- TODO 3.1: Docs
80- # fit the specified struct as a ParaKMeans model
77+ Fit the specified ParaKMeans model constructed by the user.
8178
8279 See also the [package documentation](https://pydatablog.github.io/ParallelKMeans.jl/stable).
8380"""
8481function MLJModelInterface. fit (m:: KMeans , X)
8582 # convert tabular input data into the matrix model expects. Column assumed as features so input data is permuted
8683 if ! m. copy
87- # transpose input table without copying and pass to model
88- DMatrix = convert (Array{Float64, 2 }, X)'
84+ # permutes dimensions of input table without copying and pass to model
85+ DMatrix = convert (Array{Float64, 2 }, MLJModelInterface . matrix ( X)' )
8986 else
90- # tranposes input table as a column major matrix after making a copy of the data
91- DMatrix = MLJModelInterface. matrix (X; transpose= true )
87+ # permutes dimensions of input table as a column major matrix from a copy of the data
88+ DMatrix = convert (Array{Float64, 2 }, MLJModelInterface. matrix (X, transpose= true ) )
9289 end
9390
9491 # lookup available algorithms
@@ -109,9 +106,6 @@ function MLJModelInterface.fit(m::KMeans, X)
109106end
110107
111108
112- """
113- TODO 3.2: Docs
114- """
115109function MLJModelInterface. fitted_params (model:: KMeans , fitresult)
116110 # extract what's relevant from `fitresult`
117111 results, _, _ = fitresult # unpack fitresult
@@ -129,15 +123,26 @@ end
129123# ###
130124# ### PREDICT FUNCTION
131125# ###
132- """
133- TODO 3.3: Docs
134- """
126+
135127function MLJModelInterface. transform (m:: KMeans , fitresult, Xnew)
136128 # make predictions/assignments using the learned centroids
129+
130+ if ! m. copy
131+ # permutes dimensions of input table without copying and pass to model
132+ DMatrix = convert (Array{Float64, 2 }, MLJModelInterface. matrix (Xnew)' )
133+ else
134+ # permutes dimensions of input table as a column major matrix from a copy of the data
135+ DMatrix = convert (Array{Float64, 2 }, MLJModelInterface. matrix (Xnew, transpose= true ))
136+ end
137+
138+ # TODO : Warn users if fitresult is from a `non-converged` fit?
139+ if ! fitresult[end ]. converged
140+ @warn " Failed to converged. Using last assignments to make transformations."
141+ end
142+
143+ # results from fitted model
137144 results = fitresult[1 ]
138- DMatrix = MLJModelInterface. matrix (Xnew, transpose= true )
139145
140- # TODO 3.3.1: Warn users if fitresult is from a `non-converged` fit.
141146 # use centroid matrix to assign clusters for new data
142147 centroids = results. centers
143148 distances = Distances. pairwise (Distances. SqEuclidean (), DMatrix, centroids; dims= 2 )
@@ -153,12 +158,11 @@ end
153158# TODO 4: metadata for the package and for each of the model interfaces
154159metadata_pkg .(KMeans,
155160 name = " ParallelKMeans" ,
156- uuid = " 42b8e9d4-006b-409a-8472-7f34b3fb58af" , # see your Project.toml
157- url = " https://github.com/PyDataBlog/ParallelKMeans.jl" , # URL to your package repo
158- julia = true , # is it written entirely in Julia?
159- license = " MIT" , # your package license
160- is_wrapper = false , # does it wrap around some other package?
161- )
161+ uuid = " 42b8e9d4-006b-409a-8472-7f34b3fb58af" ,
162+ url = " https://github.com/PyDataBlog/ParallelKMeans.jl" ,
163+ julia = true ,
164+ license = " MIT" ,
165+ is_wrapper = false )
162166
163167
164168# Metadata for ParaKMeans model interface
0 commit comments