Skip to content

Commit 19d9bec

Browse files
committed
Adding a test for this case
1 parent a74bd05 commit 19d9bec

File tree

1 file changed

+69
-0
lines changed

1 file changed

+69
-0
lines changed
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
test_that("autoplot works with multiple hidden units parameters", {
2+
skip_if_no_keras()
3+
skip_if_not_installed("ggplot2")
4+
5+
# 1. Define a spec with multiple hidden unit parameters
6+
model_name <- "autoplot_spec"
7+
on.exit(suppressMessages(remove_keras_spec(model_name)), add = TRUE)
8+
create_keras_sequential_spec(
9+
model_name = model_name,
10+
layer_blocks = list(
11+
input = function(model, input_shape) {
12+
keras3::keras_model_sequential(input_shape = input_shape)
13+
},
14+
dense1 = function(model, units = 10) {
15+
model |> keras3::layer_dense(units = units)
16+
},
17+
dense2 = function(model, units = 10) {
18+
model |> keras3::layer_dense(units = units)
19+
},
20+
output = function(model, num_classes) {
21+
model |>
22+
keras3::layer_dense(units = num_classes, activation = "softmax")
23+
}
24+
),
25+
mode = "classification"
26+
)
27+
28+
tune_spec <- autoplot_spec(
29+
dense1_units = tune(id = "denseone"),
30+
dense2_units = tune(id = "densetwo")
31+
) |>
32+
set_engine("keras")
33+
34+
# 2. Set up workflow and tuning grid
35+
rec <- recipes::recipe(Species ~ ., data = iris)
36+
tune_wf <- workflows::workflow(rec, tune_spec)
37+
38+
params <- tune::extract_parameter_set_dials(tune_wf)
39+
40+
# The user code should not need to change.
41+
# `hidden_units` will be `kerasnip::hidden_units` which auto-detects the id.
42+
params <- params |>
43+
update(
44+
denseone = hidden_units(range = c(4L, 8L)),
45+
densetwo = hidden_units(range = c(4L, 8L))
46+
)
47+
params$name
48+
params$id
49+
params$source
50+
params$component
51+
params$component_id
52+
params$object
53+
54+
grid <- dials::grid_regular(params, levels = 2)
55+
control <- tune::control_grid(save_pred = FALSE, verbose = FALSE)
56+
57+
# 3. Run tuning
58+
tune_res <- tune::tune_grid(
59+
tune_wf,
60+
resamples = rsample::vfold_cv(iris, v = 2),
61+
grid = grid,
62+
control = control
63+
)
64+
65+
# 4. Assert that autoplot works without error
66+
expect_no_error(
67+
ggplot2::autoplot(tune_res)
68+
)
69+
})

0 commit comments

Comments
 (0)