Skip to content

Commit 0e9de26

Browse files
committed
Adding tests for new functions
1 parent 80fb4ed commit 0e9de26

File tree

1 file changed

+155
-0
lines changed

1 file changed

+155
-0
lines changed
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
# --- Test Data ---
2+
x_train <- as.matrix(iris[, 1:4])
3+
y_train <- iris$Species
4+
# --- Tests ---
5+
6+
test_that("compile_keras_grid works for sequential models", {
7+
skip_on_cran()
8+
9+
model_name <- "test_seq_spec_compile"
10+
on.exit(suppressMessages(remove_keras_spec(model_name)), add = TRUE)
11+
12+
create_keras_sequential_spec(
13+
model_name = model_name,
14+
mode = "classification",
15+
layer_blocks = list(
16+
dense = function(model, units = 32, activation = "relu") {
17+
if (is.null(model)) {
18+
keras3::keras_model_sequential(input_shape = 4) |>
19+
keras3::layer_dense(units = units, activation = activation)
20+
} else {
21+
model |> keras3::layer_dense(units = units, activation = activation)
22+
}
23+
},
24+
output = function(model, num_classes) {
25+
model |>
26+
keras3::layer_dense(units = num_classes, activation = "softmax")
27+
}
28+
)
29+
)
30+
31+
spec <- test_seq_spec_compile() |>
32+
set_engine("keras")
33+
34+
grid <- tibble::tibble(
35+
dense_units = c(16, 32),
36+
learn_rate = c(0.01, 0.001)
37+
)
38+
39+
results <- compile_keras_grid(spec, grid, x_train, y_train)
40+
41+
expect_s3_class(results, "tbl_df")
42+
expect_equal(nrow(results), 2)
43+
expect_true(all(
44+
c(
45+
"dense_units",
46+
"learn_rate",
47+
"compiled_model",
48+
"model_summary",
49+
"error"
50+
) %in%
51+
names(results)
52+
))
53+
expect_true(all(is.na(results$error)))
54+
expect_true(all(sapply(
55+
results$compiled_model,
56+
inherits,
57+
"keras.src.models.model.Model"
58+
)))
59+
})
60+
61+
test_that("compile_keras_grid works for functional models", {
62+
skip_on_cran()
63+
64+
model_name <- "test_func_spec_compile"
65+
on.exit(suppressMessages(remove_keras_spec(model_name)), add = TRUE)
66+
67+
create_keras_functional_spec(
68+
model_name = model_name,
69+
mode = "classification",
70+
layer_blocks = list(
71+
input = function(input_shape) {
72+
keras3::layer_input(shape = input_shape)
73+
},
74+
dense = function(input, units = 32) {
75+
input |> keras3::layer_dense(units = units, activation = "relu")
76+
},
77+
output = function(dense, num_classes) {
78+
dense |>
79+
keras3::layer_dense(units = num_classes, activation = "softmax")
80+
}
81+
)
82+
)
83+
84+
spec <- test_func_spec_compile() |>
85+
set_engine("keras")
86+
87+
grid <- tibble::tibble(
88+
dense_units = c(16, 32),
89+
learn_rate = c(0.01, 0.001)
90+
)
91+
92+
results <- compile_keras_grid(spec, grid, x_train, y_train)
93+
94+
expect_s3_class(results, "tbl_df")
95+
expect_equal(nrow(results), 2)
96+
expect_true(all(
97+
c(
98+
"dense_units",
99+
"learn_rate",
100+
"compiled_model",
101+
"model_summary",
102+
"error"
103+
) %in%
104+
names(results)
105+
))
106+
expect_true(all(is.na(results$error)))
107+
expect_true(all(sapply(
108+
results$compiled_model,
109+
inherits,
110+
"keras.src.models.model.Model"
111+
)))
112+
})
113+
114+
test_that("compile_keras_grid handles errors gracefully", {
115+
skip_on_cran()
116+
117+
model_name <- "test_bad_func_spec_compile"
118+
on.exit(suppressMessages(remove_keras_spec(model_name)), add = TRUE)
119+
120+
create_keras_functional_spec(
121+
model_name = model_name,
122+
mode = "classification",
123+
layer_blocks = list(
124+
input = function(input_shape) {
125+
keras3::layer_input(shape = input_shape)
126+
},
127+
dense1 = function(input, units = 32) {
128+
input |> keras3::layer_dense(units = units, activation = "relu")
129+
},
130+
dense2 = function(units = 16) {
131+
# Missing input tensor
132+
keras3::layer_dense(units = units, activation = "relu")
133+
},
134+
output = function(dense2, num_classes) {
135+
dense2 |>
136+
keras3::layer_dense(units = num_classes, activation = "softmax")
137+
}
138+
)
139+
)
140+
141+
spec <- test_bad_func_spec_compile() |>
142+
set_engine("keras")
143+
144+
grid <- tibble::tibble(dense1_units = 16)
145+
146+
expect_warning(
147+
results <- compile_keras_grid(spec, grid, x_train, y_train),
148+
"Block 'dense2' has no inputs from other blocks."
149+
)
150+
151+
expect_s3_class(results, "tbl_df")
152+
expect_equal(nrow(results), 1)
153+
expect_false(is.na(results$error[1]))
154+
expect_true(is.null(results$compiled_model[[1]]))
155+
})

0 commit comments

Comments
 (0)