Skip to content

Commit 3e76530

Browse files
committed
fixing code typos
1 parent ec63e4b commit 3e76530

File tree

2 files changed

+29
-28
lines changed

2 files changed

+29
-28
lines changed

R/compile_keras_grid.R

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,8 @@ extract_valid_grid <- function(compiled_grid) {
216216
"`compiled_grid` must be a data frame produced by `compile_keras_grid()`."
217217
)
218218
}
219-
compiled_grid %>%
220-
dplyr::filter(is.na(error)) %>%
219+
compiled_grid |>
220+
dplyr::filter(is.na(error)) |>
221221
dplyr::select(-c(compiled_model, error))
222222
}
223223

@@ -257,7 +257,7 @@ inform_errors <- function(compiled_grid, n = 10) {
257257
"`compiled_grid` must be a data frame produced by `compile_keras_grid()`."
258258
)
259259
}
260-
error_grid <- compiled_grid %>%
260+
error_grid <- compiled_grid |>
261261
dplyr::filter(!is.na(error))
262262
if (nrow(error_grid) > 0) {
263263
cli::cli_h1("Compilation Errors Summary")
@@ -267,7 +267,7 @@ inform_errors <- function(compiled_grid, n = 10) {
267267

268268
for (i in 1:min(nrow(error_grid), n)) {
269269
row <- error_grid[i, ]
270-
params <- row %>% dplyr::select(-c(compiled_model, error))
270+
params <- row |> dplyr::select(-c(compiled_model, error))
271271
cli::cli_h2("Error {i}/{nrow(error_grid)}")
272272
cli::cli_text("Hyperparameters:")
273273
cli::cli_bullets(paste0(names(params), ": ", as.character(params)))
@@ -281,4 +281,4 @@ inform_errors <- function(compiled_grid, n = 10) {
281281
cli::cli_alert_success("All models compiled successfully!")
282282
}
283283
invisible(compiled_grid)
284-
}
284+
}

R/register_fit_predict.R

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -166,13 +166,10 @@ keras_postprocess_probs <- function(results, object) {
166166
lvls <- paste0("class", 1:ncol(res))
167167
}
168168
colnames(res) <- lvls
169-
tibble::as_tibble(res, .name_repair = "unique") %>%
169+
tibble::as_tibble(res, .name_repair = "unique") |>
170170
dplyr::rename_with(~ paste0(".pred_", name, "_", .x))
171171
}
172-
colnames(res) <- lvls
173-
tibble::as_tibble(res, .name_repair = "unique") %>%
174-
dplyr::rename_with(~ paste0(".pred_", name, "_", .x))
175-
})
172+
)
176173
return(combined_preds)
177174
} else {
178175
# Single output case: results is a matrix/array
@@ -199,26 +196,30 @@ keras_postprocess_probs <- function(results, object) {
199196
keras_postprocess_classes <- function(results, object) {
200197
if (is.list(results) && !is.null(names(results))) {
201198
# Multi-output case: results is a named list of arrays/matrices
202-
combined_preds <- purrr::map2_dfc(results, names(results), function(res, name) {
203-
lvls <- object$fit$lvl[[name]] # Assuming object$fit$lvl is a named list of levels
204-
if (is.null(lvls)) {
205-
# Fallback if levels are not specifically named for this output
206-
lvls <- paste0("class", 1:ncol(res)) # This might not be correct for classes, but a placeholder
207-
}
199+
combined_preds <- purrr::map2_dfc(
200+
results,
201+
names(results),
202+
function(res, name) {
203+
lvls <- object$fit$lvl[[name]] # Assuming object$fit$lvl is a named list of levels
204+
if (is.null(lvls)) {
205+
# Fallback if levels are not specifically named for this output
206+
lvls <- paste0("class", 1:ncol(res)) # This might not be correct for classes, but a placeholder
207+
}
208208

209-
if (ncol(res) == 1) {
210-
# Binary classification
211-
pred_class <- ifelse(res[, 1] > 0.5, lvls[2], lvls[1])
212-
pred_class <- factor(pred_class, levels = lvls)
213-
} else {
214-
# Multiclass classification
215-
pred_class_int <- apply(res, 1, which.max)
216-
pred_class <- lvls[pred_class_int]
217-
pred_class <- factor(pred_class, levels = lvls)
209+
if (ncol(res) == 1) {
210+
# Binary classification
211+
pred_class <- ifelse(res[, 1] > 0.5, lvls[2], lvls[1])
212+
pred_class <- factor(pred_class, levels = lvls)
213+
} else {
214+
# Multiclass classification
215+
pred_class_int <- apply(res, 1, which.max)
216+
pred_class <- lvls[pred_class_int]
217+
pred_class <- factor(pred_class, levels = lvls)
218+
}
219+
tibble::tibble(.pred_class = pred_class) |>
220+
dplyr::rename_with(~ paste0(".pred_class_", name))
218221
}
219-
tibble::tibble(.pred_class = pred_class) %>%
220-
dplyr::rename_with(~ paste0(".pred_class_", name))
221-
})
222+
)
222223
return(combined_preds)
223224
} else {
224225
# Single output case: results is a matrix/array

0 commit comments

Comments
 (0)