@@ -100,12 +100,11 @@ function MMI.fit(m::KMeans, verbosity::Int, X)
100100 verbose= verbose)
101101
102102 cluster_labels = MMI. categorical (1 : m. k)
103- fitresult = (result. centers, cluster_labels, result. converged)
103+ fitresult = (centers = result. centers, labels = cluster_labels, converged = result. converged)
104104 cache = nothing
105105
106106 report = (cluster_centers= result. centers, iterations= result. iterations,
107- converged= result. converged, totalcost= result. totalcost,
108- assignments= result. assignments, labels= cluster_labels)
107+ totalcost= result. totalcost, assignments= result. assignments, labels= cluster_labels)
109108
110109
111110 """
120119
121120function MMI. fitted_params (model:: KMeans , fitresult)
122121 # Centroids
123- return (cluster_centers = fitresult[ 1 ] , )
122+ return (cluster_centers = fitresult. centers , )
124123end
125124
126125
129128# ###
130129
131130function MMI. transform (m:: KMeans , fitresult, Xnew)
132- # make predictions/assignments using the learned centroids
131+ # transform new data using the fitted centroids.
133132
134133 if ! m. copy
135134 # permutes dimensions of input table without copying and pass to model
@@ -140,13 +139,12 @@ function MMI.transform(m::KMeans, fitresult, Xnew)
140139 end
141140
142141 # Warn users if fitresult is from a `non-converged` fit
143- if ! (fitresult[ end ] )
142+ if ! (fitresult. converged )
144143 @warn " Failed to converge. Using last assignments to make transformations."
145144 end
146145
147146 # use centroid matrix to assign clusters for new data
148- centroids = fitresult[1 ]
149- distances = Distances. pairwise (Distances. SqEuclidean (), DMatrix, centroids; dims= 2 )
147+ distances = Distances. pairwise (Distances. SqEuclidean (), DMatrix, fitresult. centers; dims= 2 )
150148 # preds = argmin.(eachrow(distances))
151149 return MMI. table (distances, prototype= Xnew)
152150end
0 commit comments