Skip to content

Commit 61f7ecc

Browse files
Update R/register_fit_predict.R
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 94d70bc commit 61f7ecc

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

R/register_fit_predict.R

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -156,11 +156,18 @@ keras_postprocess_numeric <- function(results, object) {
156156
keras_postprocess_probs <- function(results, object) {
157157
if (is.list(results) && !is.null(names(results))) {
158158
# Multi-output case: results is a named list of arrays/matrices
159-
combined_preds <- purrr::map2_dfc(results, names(results), function(res, name) {
160-
lvls <- object$fit$lvl[[name]] # Assuming object$fit$lvl is a named list of levels
161-
if (is.null(lvls)) {
162-
# Fallback if levels are not specifically named for this output
163-
lvls <- paste0("class", 1:ncol(res))
159+
combined_preds <- purrr::map2_dfc(
160+
results,
161+
names(results),
162+
function(res, name) {
163+
lvls <- object$fit$lvl[[name]] # Assuming object$fit$lvl is a named list of levels
164+
if (is.null(lvls)) {
165+
# Fallback if levels are not specifically named for this output
166+
lvls <- paste0("class", 1:ncol(res))
167+
}
168+
colnames(res) <- lvls
169+
tibble::as_tibble(res, .name_repair = "unique") %>%
170+
dplyr::rename_with(~ paste0(".pred_", name, "_", .x))
164171
}
165172
colnames(res) <- lvls
166173
tibble::as_tibble(res, .name_repair = "unique") %>%

0 commit comments

Comments
 (0)