Skip to content

Commit 3bf92d1

Browse files
committed
Adding tests
1 parent 2cc6ad4 commit 3bf92d1

File tree

1 file changed

+184
-0
lines changed

1 file changed

+184
-0
lines changed
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
test_that("E2E: Functional spec (regression) works", {
2+
skip_if_no_keras()
3+
4+
# Define blocks for a simple forked functional model
5+
input_block <- function(input_shape) keras3::layer_input(shape = input_shape)
6+
path_block <- function(tensor, units = 8) {
7+
tensor |> keras3::layer_dense(units = units, activation = "relu")
8+
}
9+
concat_block <- function(input_a, input_b) {
10+
keras3::layer_concatenate(list(input_a, input_b))
11+
}
12+
output_block_reg <- function(tensor) keras3::layer_dense(tensor, units = 1)
13+
14+
model_name <- "e2e_func_reg"
15+
on.exit(suppressMessages(remove_keras_spec(model_name)), add = TRUE)
16+
17+
# Create a spec with two parallel paths that are then concatenated
18+
create_keras_functional_spec(
19+
model_name = model_name,
20+
layer_blocks = list(
21+
main_input = input_block,
22+
path_a = inp_spec(path_block, "main_input"),
23+
path_b = inp_spec(path_block, "main_input"),
24+
concatenated = inp_spec(
25+
concat_block,
26+
c(path_a = "input_a", path_b = "input_b")
27+
),
28+
output = inp_spec(output_block_reg, "concatenated")
29+
),
30+
mode = "regression"
31+
)
32+
33+
spec <- e2e_func_reg(
34+
path_a_units = 32,
35+
path_b_units = 16,
36+
fit_epochs = 2
37+
) |>
38+
set_engine("keras")
39+
40+
data <- mtcars
41+
rec <- recipe(mpg ~ ., data = data)
42+
wf <- workflows::workflow(rec, spec)
43+
44+
expect_no_error(fit_obj <- parsnip::fit(wf, data = data))
45+
expect_s3_class(fit_obj, "workflow")
46+
47+
preds <- predict(fit_obj, new_data = data[1:5, ])
48+
expect_s3_class(preds, "tbl_df")
49+
expect_equal(names(preds), ".pred")
50+
expect_equal(nrow(preds), 5)
51+
expect_true(is.numeric(preds$.pred))
52+
})
53+
54+
55+
test_that("E2E: Functional spec (classification) works", {
56+
skip_if_no_keras()
57+
58+
# Define blocks for a simple forked functional model
59+
input_block <- function(input_shape) keras3::layer_input(shape = input_shape)
60+
# Add a default to `units` to work around a bug in the doc generator
61+
# when handling args with no default. This doesn't affect runtime as the
62+
# value is always overridden.
63+
path_block <- function(tensor, units = 16) {
64+
tensor |> keras3::layer_dense(units = units, activation = "relu")
65+
}
66+
concat_block <- function(input_a, input_b) {
67+
keras3::layer_concatenate(list(input_a, input_b))
68+
}
69+
output_block_class <- function(tensor, num_classes) {
70+
tensor |> keras3::layer_dense(units = num_classes, activation = "softmax")
71+
}
72+
73+
model_name <- "e2e_func_class"
74+
on.exit(suppressMessages(remove_keras_spec(model_name)), add = TRUE)
75+
76+
# Create a spec with two parallel paths that are then concatenated
77+
create_keras_functional_spec(
78+
model_name = model_name,
79+
layer_blocks = list(
80+
main_input = input_block,
81+
path_a = inp_spec(path_block, "main_input"),
82+
path_b = inp_spec(path_block, "main_input"),
83+
concatenated = inp_spec(
84+
concat_block,
85+
c(path_a = "input_a", path_b = "input_b")
86+
),
87+
output = inp_spec(output_block_class, "concatenated")
88+
),
89+
mode = "classification"
90+
)
91+
92+
spec <- e2e_func_class(
93+
path_a_units = 8,
94+
path_b_units = 4,
95+
fit_epochs = 2
96+
) |>
97+
set_engine("keras")
98+
99+
data <- iris
100+
rec <- recipe(Species ~ ., data = data)
101+
wf <- workflows::workflow(rec, spec)
102+
103+
expect_no_error(fit_obj <- parsnip::fit(wf, data = data))
104+
expect_s3_class(fit_obj, "workflow")
105+
106+
preds_class <- predict(fit_obj, new_data = data[1:5, ], type = "class")
107+
expect_s3_class(preds_class, "tbl_df")
108+
expect_equal(names(preds_class), ".pred_class")
109+
expect_equal(levels(preds_class$.pred_class), levels(data$Species))
110+
111+
preds_prob <- predict(fit_obj, new_data = data[1:5, ], type = "prob")
112+
expect_s3_class(preds_prob, "tbl_df")
113+
expect_equal(names(preds_prob), paste0(".pred_", levels(data$Species)))
114+
expect_true(all(abs(rowSums(preds_prob) - 1) < 1e-5))
115+
})
116+
117+
118+
test_that("E2E: Functional spec tuning (including repetition) works", {
119+
skip_if_no_keras()
120+
121+
input_block <- function(input_shape) keras3::layer_input(shape = input_shape)
122+
# Add a default to `units` to work around a bug in the doc generator
123+
# when handling args with no default. This doesn't affect runtime as the
124+
# value is always overridden by the tuning grid.
125+
dense_block <- function(tensor, units = 16) {
126+
tensor |> keras3::layer_dense(units = units, activation = "relu")
127+
}
128+
output_block_class <- function(tensor, num_classes) {
129+
tensor |> keras3::layer_dense(units = num_classes, activation = "softmax")
130+
}
131+
132+
model_name <- "e2e_func_tune"
133+
on.exit(suppressMessages(remove_keras_spec(model_name)), add = TRUE)
134+
135+
create_keras_functional_spec(
136+
model_name = model_name,
137+
layer_blocks = list(
138+
main_input = input_block,
139+
# This block has a single input, so it can be repeated
140+
dense_path = inp_spec(dense_block, "main_input"),
141+
output = inp_spec(output_block_class, "dense_path")
142+
),
143+
mode = "classification"
144+
)
145+
146+
tune_spec <- e2e_func_tune(
147+
num_dense_path = tune(),
148+
dense_path_units = tune(),
149+
fit_epochs = 1
150+
) |>
151+
set_engine("keras")
152+
153+
rec <- recipe(Species ~ ., data = iris)
154+
tune_wf <- workflows::workflow(rec, tune_spec)
155+
156+
folds <- rsample::vfold_cv(iris, v = 2)
157+
params <- extract_parameter_set_dials(tune_wf) |>
158+
update(
159+
num_dense_path = num_terms(c(1, 2)),
160+
dense_path_units = hidden_units(c(4, 8))
161+
)
162+
grid <- grid_regular(params, levels = 2)
163+
control <- control_grid(save_pred = FALSE, verbose = FALSE)
164+
165+
tune_res <- try(
166+
tune_grid(
167+
tune_wf,
168+
resamples = folds,
169+
grid = grid,
170+
control = control
171+
),
172+
silent = TRUE
173+
)
174+
175+
if (inherits(tune_res, "try-error")) {
176+
testthat::skip(paste("Tuning failed with error:", as.character(tune_res)))
177+
}
178+
179+
expect_s3_class(tune_res, "tune_results")
180+
181+
metrics <- collect_metrics(tune_res)
182+
expect_s3_class(metrics, "tbl_df")
183+
expect_true(all(c("num_dense_path", "dense_path_units") %in% names(metrics)))
184+
})

0 commit comments

Comments
 (0)