@@ -166,13 +166,10 @@ keras_postprocess_probs <- function(results, object) {
166166 lvls <- paste0(" class" , 1 : ncol(res ))
167167 }
168168 colnames(res ) <- lvls
169- tibble :: as_tibble(res , .name_repair = " unique" ) % > %
169+ tibble :: as_tibble(res , .name_repair = " unique" ) | >
170170 dplyr :: rename_with(~ paste0(" .pred_" , name , " _" , .x ))
171171 }
172- colnames(res ) <- lvls
173- tibble :: as_tibble(res , .name_repair = " unique" ) %> %
174- dplyr :: rename_with(~ paste0(" .pred_" , name , " _" , .x ))
175- })
172+ )
176173 return (combined_preds )
177174 } else {
178175 # Single output case: results is a matrix/array
@@ -199,26 +196,30 @@ keras_postprocess_probs <- function(results, object) {
199196keras_postprocess_classes <- function (results , object ) {
200197 if (is.list(results ) && ! is.null(names(results ))) {
201198 # Multi-output case: results is a named list of arrays/matrices
202- combined_preds <- purrr :: map2_dfc(results , names(results ), function (res , name ) {
203- lvls <- object $ fit $ lvl [[name ]] # Assuming object$fit$lvl is a named list of levels
204- if (is.null(lvls )) {
205- # Fallback if levels are not specifically named for this output
206- lvls <- paste0(" class" , 1 : ncol(res )) # This might not be correct for classes, but a placeholder
207- }
199+ combined_preds <- purrr :: map2_dfc(
200+ results ,
201+ names(results ),
202+ function (res , name ) {
203+ lvls <- object $ fit $ lvl [[name ]] # Assuming object$fit$lvl is a named list of levels
204+ if (is.null(lvls )) {
205+ # Fallback if levels are not specifically named for this output
206+ lvls <- paste0(" class" , 1 : ncol(res )) # This might not be correct for classes, but a placeholder
207+ }
208208
209- if (ncol(res ) == 1 ) {
210- # Binary classification
211- pred_class <- ifelse(res [, 1 ] > 0.5 , lvls [2 ], lvls [1 ])
212- pred_class <- factor (pred_class , levels = lvls )
213- } else {
214- # Multiclass classification
215- pred_class_int <- apply(res , 1 , which.max )
216- pred_class <- lvls [pred_class_int ]
217- pred_class <- factor (pred_class , levels = lvls )
209+ if (ncol(res ) == 1 ) {
210+ # Binary classification
211+ pred_class <- ifelse(res [, 1 ] > 0.5 , lvls [2 ], lvls [1 ])
212+ pred_class <- factor (pred_class , levels = lvls )
213+ } else {
214+ # Multiclass classification
215+ pred_class_int <- apply(res , 1 , which.max )
216+ pred_class <- lvls [pred_class_int ]
217+ pred_class <- factor (pred_class , levels = lvls )
218+ }
219+ tibble :: tibble(.pred_class = pred_class ) | >
220+ dplyr :: rename_with(~ paste0(" .pred_class_" , name ))
218221 }
219- tibble :: tibble(.pred_class = pred_class ) %> %
220- dplyr :: rename_with(~ paste0(" .pred_class_" , name ))
221- })
222+ )
222223 return (combined_preds )
223224 } else {
224225 # Single output case: results is a matrix/array
0 commit comments