diff --git a/R/remove_keras_spec.R b/R/remove_keras_spec.R index f0091bf..dd5421e 100644 --- a/R/remove_keras_spec.R +++ b/R/remove_keras_spec.R @@ -74,7 +74,13 @@ remove_keras_spec <- function(model_name, env = parent.frame()) { # 2. Nuke every parsnip object whose name starts with model_name model_env <- get_model_env() all_regs <- ls(envir = model_env) - to_kill <- grep(paste0("^", model_name), all_regs, value = TRUE) + to_kill <- intersect( + all_regs, + paste0( + model_name, + c("", "_args", "_encoding", "_fit", "_modes", "_pkgs", "_predict") + ) + ) if (length(to_kill)) { rm(list = to_kill, envir = model_env) message( diff --git a/tests/testthat/test_e2e_spec_removal.R b/tests/testthat/test_e2e_spec_removal.R index d10aa46..ef57821 100644 --- a/tests/testthat/test_e2e_spec_removal.R +++ b/tests/testthat/test_e2e_spec_removal.R @@ -28,3 +28,44 @@ test_that("E2E: Model spec removal works", { expect_false(exists(update_method_name, inherits = FALSE)) expect_no_error(parsnip:::check_model_doesnt_exist(model_name)) }) + +test_that("E2E: Model spec removal is not too aggressive", { + skip_if_no_keras() + + model_name <- "my_mlp" + model_name_2 <- "my_mlp_2" + + input_block <- function(model, input_shape) { + keras3::keras_model_sequential(input_shape = input_shape) + } + output_block <- function(model) { + model |> keras3::layer_dense(units = 1) + } + + create_keras_sequential_spec( + model_name = model_name, + layer_blocks = list(input = input_block, output = output_block), + mode = "regression" + ) + + create_keras_sequential_spec( + model_name = model_name_2, + layer_blocks = list(input = input_block, output = output_block), + mode = "regression" + ) + + expect_true(exists(model_name, inherits = FALSE)) + expect_true(exists(model_name_2, inherits = FALSE)) + expect_error(parsnip:::check_model_doesnt_exist(model_name)) + expect_error(parsnip:::check_model_doesnt_exist(model_name_2)) + + remove_keras_spec(model_name) + + expect_false(exists(model_name, inherits = FALSE)) + expect_true(exists(model_name_2, inherits = FALSE)) + expect_no_error(parsnip:::check_model_doesnt_exist(model_name)) + expect_error(parsnip:::check_model_doesnt_exist(model_name_2)) + + # cleanup + remove_keras_spec(model_name_2) +}) \ No newline at end of file