Commit abd3e85

Adding tests
1 parent 81b4193 commit abd3e85

File tree

9 files changed: +484 -5 lines changed

tests/testthat.R

Lines changed: 2 additions & 2 deletions
@@ -7,6 +7,6 @@
 # * https://testthat.r-lib.org/articles/special-files.html
 
 library(testthat)
-library(dpl)
+library(kerasnip)
 
-test_check("dpl")
+test_check("kerasnip")

tests/testthat/helper-keras.R

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
# Helper to skip tests if Keras is not configured
library(parsnip)
library(recipes)
library(workflows)
library(modeldata)
library(rsample)
library(dials)
library(tune)

skip_if_no_keras <- function() {
  testthat::skip_if_not_installed("keras3")

  # is_keras_available() checks for the python 'keras' module and a backend.
  # This is the most reliable way to check for a working installation.
  # testthat::skip_if_not(
  #   keras3::is_keras_available(),
  #   "Keras 3 and a backend (e.g., tensorflow) are not available for testing"
  # )
}
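
Note: testthat sources tests/testthat/helper-*.R files before any test file runs, so skip_if_no_keras() and the attached packages are available throughout the suite. A minimal usage sketch (hypothetical test, mirroring the pattern used by the suites in this commit):

test_that("sketch: guard a Keras-dependent test", {
  # Skips on machines where keras3 is not installed, instead of failing.
  skip_if_no_keras()
  expect_true(TRUE)
})
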
Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
test_that("E2E: Classification spec generation, fitting, and prediction works", {
  skip_if_no_keras()

  input_block_class <- function(model, input_shape) {
    keras3::keras_model_sequential(input_shape = input_shape)
  }
  dense_block_class <- function(model, units = 16) {
    model |>
      keras3::layer_dense(units = units, activation = "relu")
  }
  output_block_class <- function(model, num_classes) {
    model |> keras3::layer_dense(units = num_classes, activation = "softmax")
  }

  create_keras_spec(
    model_name = "e2e_mlp_class",
    layer_blocks = list(
      input = input_block_class,
      dense = dense_block_class,
      output = output_block_class
    ),
    mode = "classification"
  )

  spec <- e2e_mlp_class(
    num_dense = 2,
    dense_units = 8,
    epochs = 2
  ) |>
    set_engine("keras")

  # --- Multiclass test ---
  multi_data <- iris
  rec_multi <- recipe(Species ~ ., data = multi_data)
  wf_multi <- workflow(rec_multi, spec)

  expect_no_error(fit_multi <- fit(wf_multi, data = multi_data))
  expect_s3_class(fit_multi, "workflow")

  preds_class_multi <- predict(
    fit_multi,
    new_data = multi_data[1:5, ],
    type = "class"
  )
  expect_s3_class(preds_class_multi, "tbl_df")
  expect_equal(names(preds_class_multi), ".pred_class")
  expect_equal(nrow(preds_class_multi), 5)
  expect_equal(
    levels(preds_class_multi$.pred_class),
    levels(multi_data$Species)
  )

  preds_prob_multi <- predict(
    fit_multi,
    new_data = multi_data[1:5, ],
    type = "prob"
  )
  expect_s3_class(preds_prob_multi, "tbl_df")
  expect_equal(
    names(preds_prob_multi),
    paste0(".pred_", levels(multi_data$Species))
  )
  expect_equal(nrow(preds_prob_multi), 5)
  expect_true(all(abs(rowSums(preds_prob_multi) - 1) < 1e-5))

  # --- Binary test ---
  binary_data <- modeldata::two_class_dat
  rec_bin <- recipe(Class ~ ., data = binary_data)
  wf_bin <- workflow(rec_bin, spec)

  expect_no_error(fit_bin <- fit(wf_bin, data = binary_data))
  expect_s3_class(fit_bin, "workflow")

  preds_class_bin <- predict(
    fit_bin,
    new_data = binary_data[1:5, ],
    type = "class"
  )
  expect_s3_class(preds_class_bin, "tbl_df")
  expect_equal(names(preds_class_bin), ".pred_class")
  expect_equal(nrow(preds_class_bin), 5)
  expect_equal(levels(preds_class_bin$.pred_class), levels(binary_data$Class))

  preds_prob_bin <- predict(
    fit_bin,
    new_data = binary_data[1:5, ],
    type = "prob"
  )
  expect_s3_class(preds_prob_bin, "tbl_df")
  expect_equal(names(preds_prob_bin), c(".pred_Class1", ".pred_Class2"))
  expect_equal(nrow(preds_prob_bin), 5)
  expect_true(all(abs(rowSums(preds_prob_bin) - 1) < 1e-5))
})
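
Note: the convention this test exercises is that create_keras_spec() generates a spec function named after model_name, where each block contributes a num_<block> argument (how many times the block repeats) and <block>_<arg> arguments (the block's own parameters), hence num_dense and dense_units above. A condensed sketch of the same flow (demo_mlp is a hypothetical name; assumes kerasnip is attached and a keras3 backend is installed):

create_keras_spec(
  model_name = "demo_mlp",
  layer_blocks = list(
    input = function(model, input_shape) {
      keras3::keras_model_sequential(input_shape = input_shape)
    },
    dense = function(model, units = 16) {
      model |> keras3::layer_dense(units = units, activation = "relu")
    },
    output = function(model, num_classes) {
      model |> keras3::layer_dense(units = num_classes, activation = "softmax")
    }
  ),
  mode = "classification"
)

# demo_mlp() now exists: num_dense repeats the dense block, and
# dense_units maps to its units argument.
fit_obj <- demo_mlp(num_dense = 1, dense_units = 4, epochs = 1) |>
  set_engine("keras") |>
  fit(Species ~ ., data = iris)

predict(fit_obj, new_data = head(iris), type = "prob")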

tests/testthat/test-e2e-features.R

Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
test_that("E2E: Customizing main arguments works", {
  skip_if_no_keras()

  input_block_feat <- function(model, input_shape) {
    keras3::keras_model_sequential(input_shape = input_shape)
  }
  dense_block_feat <- function(model, units = 16) {
    model |> keras3::layer_dense(units = units, activation = "relu")
  }
  output_block_feat <- function(model) {
    model |> keras3::layer_dense(units = 1)
  }

  create_keras_spec(
    model_name = "e2e_mlp_feat",
    layer_blocks = list(
      input = input_block_feat,
      dense = dense_block_feat,
      output = output_block_feat
    ),
    mode = "regression"
  )

  # Main arguments (like compile_*) should be set in the spec function,
  # not in set_engine().
  spec <- e2e_mlp_feat(
    epochs = 2,
    compile_optimizer = "sgd",
    compile_loss = "mae",
    compile_metrics = c("mean_squared_error")
  ) |>
    parsnip::set_engine("keras")

  # This should now run without the parsnip warning about removing arguments
  fit_obj <- NULL
  expect_no_warning(
    fit_obj <- parsnip::fit(spec, mpg ~ ., data = mtcars)
  )

  # Also verify the arguments were correctly used during compilation
  keras_model <- fit_obj$fit$fit
  compiled_loss <- keras_model$loss
  compiled_optimizer <- tolower(keras_model$optimizer$name)
  compiled_metrics <- sapply(
    keras_model$metrics[[2]]$metrics,
    function(m) m$name
  )

  # Keras might add suffixes or use different casings, so check flexibly
  expect_true(grepl("mae", compiled_loss))
  expect_true(grepl("sgd", compiled_optimizer))
  expect_true("mean_squared_error" %in% compiled_metrics)
})
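
Note: for contrast, a sketch of the discouraged pattern the comment in the test refers to (hypothetical spec_bad, reusing e2e_mlp_feat). Per that comment, routing compile_* arguments through set_engine() rather than the generated spec function is what previously triggered parsnip's warning about removing arguments:

# Discouraged (sketch): per the test comment, passing compile_* here
# instead of in e2e_mlp_feat() produced parsnip's removed-arguments warning.
spec_bad <- e2e_mlp_feat(epochs = 2) |>
  parsnip::set_engine("keras", compile_loss = "mae", compile_optimizer = "sgd")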

test_that("E2E: Customizing fit arguments works", {
  skip_if_no_keras()

  input_block_fit <- function(model, input_shape) {
    keras3::keras_model_sequential(input_shape = input_shape)
  }
  dense_block_fit <- function(model, units = 16) {
    model |> keras3::layer_dense(units = units, activation = "relu")
  }
  output_block_fit <- function(model) {
    model |> keras3::layer_dense(units = 1)
  }

  create_keras_spec(
    model_name = "e2e_mlp_fit",
    layer_blocks = list(
      input = input_block_fit,
      dense = dense_block_fit,
      output = output_block_fit
    ),
    mode = "regression"
  )

  # Fit arguments (like validation_split, callbacks) should be set in the
  # spec function, not in set_engine().
  spec <- e2e_mlp_fit(
    fit_validation_split = 0.2,
    fit_callbacks = list(keras3::callback_early_stopping(patience = 1)),
    fit_epochs = 3,
    compile_metrics = "mean_squared_error"
  ) |>
    parsnip::set_engine("keras")

  # This will run without error if the arguments are passed correctly
  fit_obj <- NULL
  expect_no_error(fit_obj <- parsnip::fit(spec, mpg ~ ., data = mtcars))

  # Check that the callback was used (model should stop early)
  expect_lt(length(fit_obj$fit$history$metrics$loss), 5)
})

test_that("E2E: Setting num_blocks = 0 works", {
  skip_if_no_keras()

  input_block_zero <- function(model, input_shape) {
    keras3::keras_model_sequential(input_shape = input_shape)
  }
  dense_block_zero <- function(model, units = 16) {
    model |> keras3::layer_dense(units = units, activation = "relu")
  }
  output_block_zero <- function(model) {
    model |> keras3::layer_dense(units = 1)
  }

  create_keras_spec(
    model_name = "e2e_mlp_zero",
    layer_blocks = list(
      input = input_block_zero,
      dense = dense_block_zero,
      output = output_block_zero
    ),
    mode = "regression"
  )

  spec <- e2e_mlp_zero(num_dense = 0, epochs = 2) |>
    parsnip::set_engine("keras")
  # This should fit a model with only an input and output layer
  expect_no_error(parsnip::fit(spec, mpg ~ ., data = mtcars))
})

test_that("E2E: Error handling for reserved names works", {
  bad_blocks <- list(
    compile = function(model) model, # "compile" is a reserved name
    dense = function(model, u = 1) model |> keras3::layer_dense(units = u)
  )

  expect_error(
    create_keras_spec("bad_spec", bad_blocks),
    regexp = "`compile` and `optimizer` are protected names"
  )
})
Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
test_that("E2E: Multi-block model tuning works", {
  skip_if_no_keras()

  input_block_mb <- function(model, input_shape) {
    keras3::keras_model_sequential(input_shape = input_shape)
  }

  starting_layers <- function(model, layer1_units = 16, layer2_units = 32) {
    model |>
      keras3::layer_dense(units = layer1_units, activation = "relu") |>
      keras3::layer_dense(units = layer2_units, activation = "relu")
  }

  ending_layers <- function(model, units = 32, dropout = 0.2) {
    model |>
      keras3::layer_dense(units = units, activation = "relu") |>
      keras3::layer_dropout(rate = dropout)
  }

  output_block_mb <- function(model, num_classes) {
    model |> keras3::layer_dense(units = num_classes, activation = "softmax")
  }

  create_keras_spec(
    model_name = "mb_mt",
    layer_blocks = list(
      input = input_block_mb,
      start = starting_layers,
      end = ending_layers,
      output = output_block_mb
    ),
    mode = "classification"
  )

  tune_spec <- mb_mt(
    num_start = tune(),
    start_layer1_units = tune(),
    start_layer2_units = tune(),
    end_units = tune(),
    epochs = 1
  ) |>
    set_engine("keras")

  rec <- recipe(Species ~ ., data = iris)
  wf <- workflow(rec) |>
    add_model(tune_spec)

  folds <- rsample::vfold_cv(iris, v = 2)

  params <- extract_parameter_set_dials(wf) |>
    update(
      num_start = dials::num_terms(c(1, 2)),
      start_layer1_units = dials::hidden_units(c(4, 8)),
      start_layer2_units = dials::hidden_units(c(8, 16)),
      end_units = dials::hidden_units(c(4, 8))
    )

  grid <- grid_regular(params, levels = 2)
  control <- control_grid(
    save_pred = FALSE,
    verbose = FALSE,
    save_workflow = TRUE
  )

  # Use try() because tuning can sometimes fail for non-package reasons
  tune_res <- try(
    tune_grid(
      wf,
      resamples = folds,
      grid = grid,
      control = control
    ),
    silent = TRUE
  )

  if (inherits(tune_res, "try-error")) {
    testthat::skip(paste("Tuning failed with error:", as.character(tune_res)))
  }

  expect_s3_class(tune_res, "tune_results")

  metrics <- collect_metrics(tune_res)
  expect_s3_class(metrics, "tbl_df")
  expect_true(all(
    c("num_start", "start_layer1_units", "start_layer2_units", "end_units") %in%
      names(metrics)
  ))

  expect_no_error(
    best_fit <- tune::fit_best(tune_res)
  )
  expect_s3_class(best_fit, "workflow")
})
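
Note: fit_best() is the shortcut used above. An equivalent sketch via the longer tidymodels route, assuming tune's default classification metrics (accuracy among them) were computed during tuning:

# Pick the winning parameters, finalize the workflow, and refit on the
# full training data.
best_params <- tune::select_best(tune_res, metric = "accuracy")
final_wf <- tune::finalize_workflow(wf, best_params)
final_fit <- fit(final_wf, data = iris)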
