Skip to content

Commit 7b0bbec

Browse files
authored
Merge pull request #5 from davidrsch/remove_keras_spec_issue
Fixing issue with remove_keras_spec
2 parents cab9e65 + 91af678 commit 7b0bbec

File tree

4 files changed

+55
-38
lines changed

4 files changed

+55
-38
lines changed

R/create_keras_spec.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ create_keras_spec <- function(
117117
register_core_model(model_name, mode)
118118
register_model_args(model_name, args_info$parsnip_names)
119119
register_fit_predict(model_name, mode, layer_blocks)
120-
register_update_method(model_name, args_info$parsnip_names)
120+
register_update_method(model_name, args_info$parsnip_names, env = env)
121121

122122
env_poke(env, model_name, spec_fun)
123123
invisible(NULL)

R/create_keras_spec_helpers.R

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -529,8 +529,9 @@ register_fit_predict <- function(model_name, mode, layer_blocks) {
529529
#' @param model_name The name of the new model.
530530
#' @param parsnip_names A character vector of all argument names.
531531
#' @return Invisibly returns `NULL`. Called for its side effects.
532+
#' @param env The environment in which to create the update method.
532533
#' @noRd
533-
register_update_method <- function(model_name, parsnip_names) {
534+
register_update_method <- function(model_name, parsnip_names, env) {
534535
# Build function signature
535536
update_args_list <- c(
536537
list(object = rlang::missing_arg(), parameters = rlang::expr(NULL)),
@@ -572,6 +573,8 @@ register_update_method <- function(model_name, parsnip_names) {
572573
body = update_body
573574
)
574575
method_name <- paste0("update.", model_name)
575-
rlang::env_poke(environment(), method_name, update_func)
576-
registerS3method("update", model_name, update_func, envir = environment())
576+
# Poke the function into the target environment (e.g., .GlobalEnv) so that
577+
# S3 dispatch can find it.
578+
rlang::env_poke(env, method_name, update_func)
579+
registerS3method("update", model_name, update_func, envir = env)
577580
}

R/remove_spec.R

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,23 +29,35 @@
2929
#' !exists("my_temp_model")
3030
#' }
3131
remove_keras_spec <- function(model_name, env = parent.frame()) {
32-
spec_found <- FALSE
33-
if (exists(model_name, envir = env, inherits = FALSE)) {
34-
obj <- get(model_name, envir = env)
35-
if (is.function(obj)) {
36-
remove(list = model_name, envir = env)
37-
spec_found <- TRUE
38-
}
32+
# 1. Remove the spec + update fn from the user env
33+
if (
34+
exists(model_name, envir = env, inherits = FALSE) &&
35+
is.function(get(model_name, envir = env))
36+
) {
37+
remove(list = model_name, envir = env)
38+
}
39+
update_fn <- paste0("update.", model_name)
40+
if (exists(update_fn, envir = env, inherits = FALSE)) {
41+
remove(list = update_fn, envir = env)
42+
}
43+
44+
# 2. Nuke every parsnip object whose name starts with model_name
45+
model_env <- parsnip:::get_model_env()
46+
all_regs <- ls(envir = model_env)
47+
to_kill <- grep(paste0("^", model_name), all_regs, value = TRUE)
48+
if (length(to_kill)) {
49+
rm(list = to_kill, envir = model_env)
50+
message(
51+
"Removed from parsnip registry objects: ",
52+
paste(to_kill, collapse = ", ")
53+
)
3954
}
4055

41-
# Also remove the associated update method
42-
update_method_name <- paste0("update.", model_name)
43-
# The update method is in the package namespace. `environment()` inside a
44-
# package function returns the package namespace.
45-
pkg_env <- environment()
46-
if (exists(update_method_name, envir = pkg_env, inherits = FALSE)) {
47-
remove(list = update_method_name, envir = pkg_env)
56+
# 3. Remove the entry in get_model_env()$models
57+
if ("models" %in% all_regs && model_name %in% model_env$models) {
58+
model_env$models <- model_env$models[-which(model_name == model_env$models)]
59+
message("Removed '", model_name, "' from parsnip:::get_model_env()$models")
4860
}
4961

50-
invisible(spec_found)
62+
invisible(TRUE)
5163
}
Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,30 @@
11
test_that("E2E: Model spec removal works", {
2-
input_block_rm <- function(model, input_shape) {
2+
skip_if_no_keras()
3+
4+
model_name <- "removable_model"
5+
6+
input_block <- function(model, input_shape) {
37
keras3::keras_model_sequential(input_shape = input_shape)
48
}
5-
hidden_block_rm <- function(model, units = 16) {
6-
model |> keras3::layer_dense(units = units, activation = "relu")
9+
output_block <- function(model) {
10+
model |> keras3::layer_dense(units = 1)
711
}
8-
output_block_rm <- function(model, num_classes) {
9-
model |> keras3::layer_dense(units = num_classes, activation = "softmax")
10-
}
11-
12-
model_to_remove <- "e2e_mlp_to_remove"
1312

1413
create_keras_spec(
15-
model_name = model_to_remove,
16-
layer_blocks = list(
17-
input = input_block_rm,
18-
hidden = hidden_block_rm,
19-
output = output_block_rm
20-
),
21-
mode = "classification"
14+
model_name = model_name,
15+
layer_blocks = list(input = input_block, output = output_block),
16+
mode = "regression"
2217
)
2318

24-
expect_true(exists(model_to_remove, inherits = FALSE))
25-
expect_true(remove_keras_spec(model_to_remove))
26-
expect_false(exists(model_to_remove, inherits = FALSE))
27-
expect_false(remove_keras_spec("a_non_existent_model"))
19+
update_method_name <- paste0("update.", model_name)
20+
21+
expect_true(exists(model_name, inherits = FALSE))
22+
expect_true(exists(update_method_name, inherits = FALSE))
23+
expect_error(parsnip:::check_model_doesnt_exist(model_name))
24+
25+
remove_keras_spec(model_name)
26+
27+
expect_false(exists(model_name, inherits = FALSE))
28+
expect_false(exists(update_method_name, inherits = FALSE))
29+
expect_no_error(parsnip:::check_model_doesnt_exist(model_name))
2830
})

0 commit comments

Comments
 (0)