Skip to content

Commit 3b81d23

Browse files
committed
Fixing detected issue witth num_layer = 0 and making tests more robust
1 parent a1622be commit 3b81d23

File tree

4 files changed

+89
-7
lines changed

4 files changed

+89
-7
lines changed

R/generic_functional_fit.R

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,15 @@ generic_functional_fit <- function(
128128

129129
# --- Get Repetition Count ---
130130
num_repeats_arg <- paste0("num_", block_name)
131-
num_repeats <- all_args[[num_repeats_arg]] %||% 1
131+
num_repeats_val <- all_args[[num_repeats_arg]]
132+
133+
# If num_repeats_val is NULL or zapped, default to 1.
134+
# Otherwise, use the value provided by the user.
135+
if (is.null(num_repeats_val) || inherits(num_repeats_val, "rlang_zap")) {
136+
num_repeats <- 1
137+
} else {
138+
num_repeats <- as.integer(num_repeats_val)
139+
}
132140

133141
# --- Get Hyperparameters for this block ---
134142
# Hyperparameters are formals that are NOT other block names (graph connections)

R/generic_sequential_fit.R

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,14 @@ generic_sequential_fit <- function(
118118

119119
num_repeats_arg <- paste0("num_", block_name)
120120
num_repeats_val <- all_args[[num_repeats_arg]]
121-
num_repeats <- num_repeats_val %||% 1
121+
122+
# If num_repeats_val is NULL or zapped, default to 1.
123+
# Otherwise, use the value provided by the user.
124+
if (is.null(num_repeats_val) || inherits(num_repeats_val, "rlang_zap")) {
125+
num_repeats <- 1
126+
} else {
127+
num_repeats <- as.integer(num_repeats_val)
128+
}
122129

123130
# Get the arguments for this specific block from `...`
124131
block_arg_names <- names(block_fmls)[-1] # Exclude 'model'

tests/testthat/test-e2e-features.R

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,14 +102,19 @@ test_that("E2E: Customizing fit arguments works", {
102102
expect_lt(length(fit_obj$fit$history$metrics$loss), 5)
103103
})
104104

105-
test_that("E2E: Setting num_blocks = 0 works", {
105+
test_that("E2E: Setting num_blocks = 0 works for sequential models", {
106106
skip_if_no_keras()
107107

108108
input_block_zero <- function(model, input_shape) {
109109
keras3::keras_model_sequential(input_shape = input_shape)
110110
}
111111
dense_block_zero <- function(model, units = 16) {
112-
model |> keras3::layer_dense(units = units, activation = "relu")
112+
model |>
113+
keras3::layer_dense(
114+
units = units,
115+
activation = "relu",
116+
name = "i_should_not_exist"
117+
)
113118
}
114119
output_block_zero <- function(model) {
115120
model |> keras3::layer_dense(units = 1)
@@ -128,10 +133,18 @@ test_that("E2E: Setting num_blocks = 0 works", {
128133
mode = "regression"
129134
)
130135

131-
spec <- e2e_mlp_zero(num_dense = 0, fit_epochs = 2) |>
136+
spec <- e2e_mlp_zero(num_dense = 0, fit_epochs = 1) |>
132137
parsnip::set_engine("keras")
133-
# This should fit a model with only an input and output layer
134-
expect_no_error(parsnip::fit(spec, mpg ~ ., data = mtcars))
138+
139+
fit_obj <- parsnip::fit(spec, mpg ~ ., data = mtcars)
140+
141+
# Check that the dense layer is NOT in the model
142+
keras_model <- fit_obj |> extract_keras_summary()
143+
expect_equal(length(keras_model$layers), 1) # Output layers only
144+
145+
# Check layer names explicitly
146+
layer_names <- sapply(keras_model$layers, function(l) l$name)
147+
expect_false("i_should_not_exist" %in% layer_names)
135148
})
136149

137150
test_that("E2E: Error handling for reserved names works", {

tests/testthat/test-e2e-functional.R

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,3 +182,57 @@ test_that("E2E: Functional spec tuning (including repetition) works", {
182182
expect_s3_class(metrics, "tbl_df")
183183
expect_true(all(c("num_dense_path", "dense_path_units") %in% names(metrics)))
184184
})
185+
186+
test_that("E2E: Block repetition works for functional models", {
187+
skip_if_no_keras()
188+
189+
input_block <- function(input_shape) keras3::layer_input(shape = input_shape)
190+
dense_block <- function(tensor, units = 8) {
191+
tensor |> keras3::layer_dense(units = units, activation = "relu")
192+
}
193+
output_block <- function(tensor) keras3::layer_dense(tensor, units = 1)
194+
195+
model_name <- "e2e_func_repeat"
196+
on.exit(suppressMessages(remove_keras_spec(model_name)), add = TRUE)
197+
198+
create_keras_functional_spec(
199+
model_name = model_name,
200+
layer_blocks = list(
201+
main_input = input_block,
202+
dense_path = inp_spec(dense_block, "main_input"),
203+
output = inp_spec(output_block, "dense_path")
204+
),
205+
mode = "regression"
206+
)
207+
208+
# --- Test with 1 repetition ---
209+
spec_1 <- e2e_func_repeat(num_dense_path = 1, fit_epochs = 1) |>
210+
set_engine("keras")
211+
fit_1 <- fit(spec_1, mpg ~ ., data = mtcars)
212+
model_1_layers <- fit_1 |>
213+
extract_keras_summary() |>
214+
pluck("layers")
215+
216+
# Expect 3 layers: Input, Dense, Output
217+
expect_equal(length(model_1_layers), 3)
218+
219+
# --- Test with 2 repetitions ---
220+
spec_2 <- e2e_func_repeat(num_dense_path = 2, fit_epochs = 1) |>
221+
set_engine("keras")
222+
fit_2 <- fit(spec_2, mpg ~ ., data = mtcars)
223+
model_2_layers <- fit_2 |>
224+
extract_keras_summary() |>
225+
pluck("layers")
226+
# Expect 4 layers: Input, Dense, Dense, Output
227+
expect_equal(length(model_2_layers), 4)
228+
229+
# --- Test with 0 repetitions ---
230+
spec_3 <- e2e_func_repeat(num_dense_path = 0, fit_epochs = 1) |>
231+
set_engine("keras")
232+
fit_3 <- fit(spec_3, mpg ~ ., data = mtcars)
233+
model_3_layers <- fit_3 |>
234+
extract_keras_summary() |>
235+
pluck("layers")
236+
# Expect 2 layers: Input, Output
237+
expect_equal(length(model_3_layers), 2)
238+
})

0 commit comments

Comments
 (0)