Skip to content

Commit 6b491a5

Browse files
committed
Preparing spec_helpers for functional support
1 parent 3270c84 commit 6b491a5

File tree

1 file changed

+54
-20
lines changed

1 file changed

+54
-20
lines changed

R/create_keras_spec_helpers.R

Lines changed: 54 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,36 @@
11
#' Discover and Collect Model Specification Arguments
22
#'
3-
#' Introspects the provided layer block functions to generate a list of
4-
#' arguments for the new model specification. This includes arguments for
5-
#' block repetition (`num_*`), block-specific hyperparameters (`block_*`),
6-
#' and global training parameters.
3+
#' @description
4+
#' This internal helper introspects the user-provided `layer_blocks` functions
5+
#' to generate a complete list of arguments for the new model specification.
6+
#' The logic for discovering arguments differs for sequential and functional models.
7+
#'
8+
#' @details
9+
#' For **sequential models** (`functional = FALSE`):
10+
#' - It creates `num_{block_name}` arguments to control block repetition.
11+
#' - It inspects the arguments of each block function, skipping the first
12+
#' (assumed to be the `model` object), to find tunable hyperparameters.
13+
#'
14+
#' For **functional models** (`functional = TRUE`):
15+
#' - It does **not** create `num_{block_name}` arguments.
16+
#' - It inspects the arguments of each block function. Arguments whose names
17+
#' match other block names are considered graph connections (inputs) and are
18+
#' ignored. The remaining arguments are treated as tunable hyperparameters.
19+
#'
20+
#' In both cases, it also adds global training parameters (like `epochs`) and
21+
#' filters out special engine-supplied arguments (`input_shape`, `num_classes`).
722
#'
823
#' @param layer_blocks A named list of functions defining Keras layer blocks.
24+
#' @param functional A logical. If `TRUE`, uses discovery logic for the
25+
#' Functional API. If `FALSE`, uses logic for the Sequential API.
926
#' @param global_args A character vector of global arguments to add to the
1027
#' specification (e.g., "epochs").
1128
#' @return A list containing two elements:
12-
#' - `all_args`: A named list of arguments for the new function signature,
13-
#' initialized with `rlang::zap()`.
14-
#' - `parsnip_names`: A character vector of all argument names for `parsnip`.
29+
#'
1530
#' @noRd
1631
collect_spec_args <- function(
1732
layer_blocks,
33+
functional,
1834
global_args = c(
1935
"epochs",
2036
"batch_size",
@@ -36,23 +52,39 @@ collect_spec_args <- function(
3652
all_args <- list()
3753
parsnip_names <- character()
3854

55+
block_names <- names(layer_blocks)
56+
3957
# block repetition counts (e.g., num_dense)
40-
for (block in names(layer_blocks)) {
41-
num_name <- paste0("num_", block)
58+
for (block_name in block_names) {
59+
num_name <- paste0("num_", block_name)
4260
all_args[[num_name]] <- rlang::zap()
4361
parsnip_names <- c(parsnip_names, num_name)
4462
}
4563

4664
# These args are passed by the fit engine, not set by the user in the spec
4765
engine_args <- c("input_shape", "num_classes")
48-
# block-specific parameters (skip first 'model' formal)
49-
for (block in names(layer_blocks)) {
50-
fmls_to_process <- rlang::fn_fmls(layer_blocks[[block]])[-1]
51-
# Filter out arguments that are provided by the fitting engine
52-
for (arg in names(fmls_to_process[
53-
!names(fmls_to_process) %in% engine_args
54-
])) {
55-
full <- paste0(block, "_", arg)
66+
# Discover block-specific hyperparameters
67+
for (block_name in block_names) {
68+
block_fmls <- rlang::fn_fmls(layer_blocks[[block_name]])
69+
70+
if (isTRUE(functional)) {
71+
# For functional models, hyperparameters are arguments that are NOT
72+
# names of other blocks (which are graph connections).
73+
hyperparam_names <- setdiff(
74+
names(block_fmls),
75+
c(block_names, engine_args)
76+
)
77+
} else {
78+
# For sequential models, hyperparameters are all args except the first
79+
# ('model') and special engine args.
80+
fmls_to_process <- if (length(block_fmls) > 0) block_fmls[-1] else list()
81+
hyperparam_names <- names(fmls_to_process)[
82+
!names(fmls_to_process) %in% engine_args
83+
]
84+
}
85+
86+
for (arg in hyperparam_names) {
87+
full <- paste0(block_name, "_", arg)
5688
all_args[[full]] <- rlang::zap()
5789
parsnip_names <- c(parsnip_names, full)
5890
}
@@ -69,8 +101,10 @@ collect_spec_args <- function(
69101

70102
#' Internal Implementation for Creating Keras Specifications
71103
#'
72-
#' This is the core logic for both `create_keras_sequential_spec` and
73-
#' `create_keras_functional_spec`. It is not intended for direct use.
104+
#' @description
105+
#' This is the core implementation for both `create_keras_sequential_spec()` and
106+
#' `create_keras_functional_spec()`. It orchestrates the argument collection,
107+
#' function building, and `parsnip` registration steps.
74108
#'
75109
#' @inheritParams create_keras_sequential_spec
76110
#' @param functional A logical, if `TRUE`, registers the model to be fit with
@@ -85,7 +119,7 @@ create_keras_spec_impl <- function(
85119
functional,
86120
env
87121
) {
88-
args_info <- collect_spec_args(layer_blocks)
122+
args_info <- collect_spec_args(layer_blocks, functional = functional)
89123
spec_fun <- build_spec_function(
90124
model_name,
91125
mode,

0 commit comments

Comments
 (0)