11# ' Discover and Collect Model Specification Arguments
22# '
3- # ' Introspects the provided layer block functions to generate a list of
4- # ' arguments for the new model specification. This includes arguments for
5- # ' block repetition (`num_*`), block-specific hyperparameters (`block_*`),
6- # ' and global training parameters.
3+ # ' @description
4+ # ' This internal helper introspects the user-provided `layer_blocks` functions
5+ # ' to generate a complete list of arguments for the new model specification.
6+ # ' The logic for discovering arguments differs for sequential and functional models.
7+ # '
8+ # ' @details
9+ # ' For **sequential models** (`functional = FALSE`):
10+ # ' - It creates `num_{block_name}` arguments to control block repetition.
11+ # ' - It inspects the arguments of each block function, skipping the first
12+ # ' (assumed to be the `model` object), to find tunable hyperparameters.
13+ # '
14+ # ' For **functional models** (`functional = TRUE`):
15+ # ' - It does **not** create `num_{block_name}` arguments.
16+ # ' - It inspects the arguments of each block function. Arguments whose names
17+ # ' match other block names are considered graph connections (inputs) and are
18+ # ' ignored. The remaining arguments are treated as tunable hyperparameters.
19+ # '
20+ # ' In both cases, it also adds global training parameters (like `epochs`) and
21+ # ' filters out special engine-supplied arguments (`input_shape`, `num_classes`).
722# '
823# ' @param layer_blocks A named list of functions defining Keras layer blocks.
24+ # ' @param functional A logical. If `TRUE`, uses discovery logic for the
25+ # ' Functional API. If `FALSE`, uses logic for the Sequential API.
926# ' @param global_args A character vector of global arguments to add to the
1027# ' specification (e.g., "epochs").
1128# ' @return A list containing two elements:
12- # ' - `all_args`: A named list of arguments for the new function signature,
13- # ' initialized with `rlang::zap()`.
14- # ' - `parsnip_names`: A character vector of all argument names for `parsnip`.
29+ # '
1530# ' @noRd
1631collect_spec_args <- function (
1732 layer_blocks ,
33+ functional ,
1834 global_args = c(
1935 " epochs" ,
2036 " batch_size" ,
@@ -36,23 +52,39 @@ collect_spec_args <- function(
3652 all_args <- list ()
3753 parsnip_names <- character ()
3854
55+ block_names <- names(layer_blocks )
56+
3957 # block repetition counts (e.g., num_dense)
40- for (block in names( layer_blocks ) ) {
41- num_name <- paste0(" num_" , block )
58+ for (block_name in block_names ) {
59+ num_name <- paste0(" num_" , block_name )
4260 all_args [[num_name ]] <- rlang :: zap()
4361 parsnip_names <- c(parsnip_names , num_name )
4462 }
4563
4664 # These args are passed by the fit engine, not set by the user in the spec
4765 engine_args <- c(" input_shape" , " num_classes" )
48- # block-specific parameters (skip first 'model' formal)
49- for (block in names(layer_blocks )) {
50- fmls_to_process <- rlang :: fn_fmls(layer_blocks [[block ]])[- 1 ]
51- # Filter out arguments that are provided by the fitting engine
52- for (arg in names(fmls_to_process [
53- ! names(fmls_to_process ) %in% engine_args
54- ])) {
55- full <- paste0(block , " _" , arg )
66+ # Discover block-specific hyperparameters
67+ for (block_name in block_names ) {
68+ block_fmls <- rlang :: fn_fmls(layer_blocks [[block_name ]])
69+
70+ if (isTRUE(functional )) {
71+ # For functional models, hyperparameters are arguments that are NOT
72+ # names of other blocks (which are graph connections).
73+ hyperparam_names <- setdiff(
74+ names(block_fmls ),
75+ c(block_names , engine_args )
76+ )
77+ } else {
78+ # For sequential models, hyperparameters are all args except the first
79+ # ('model') and special engine args.
80+ fmls_to_process <- if (length(block_fmls ) > 0 ) block_fmls [- 1 ] else list ()
81+ hyperparam_names <- names(fmls_to_process )[
82+ ! names(fmls_to_process ) %in% engine_args
83+ ]
84+ }
85+
86+ for (arg in hyperparam_names ) {
87+ full <- paste0(block_name , " _" , arg )
5688 all_args [[full ]] <- rlang :: zap()
5789 parsnip_names <- c(parsnip_names , full )
5890 }
@@ -69,8 +101,10 @@ collect_spec_args <- function(
69101
70102# ' Internal Implementation for Creating Keras Specifications
71103# '
72- # ' This is the core logic for both `create_keras_sequential_spec` and
73- # ' `create_keras_functional_spec`. It is not intended for direct use.
104+ # ' @description
105+ # ' This is the core implementation for both `create_keras_sequential_spec()` and
106+ # ' `create_keras_functional_spec()`. It orchestrates the argument collection,
107+ # ' function building, and `parsnip` registration steps.
74108# '
75109# ' @inheritParams create_keras_sequential_spec
76110# ' @param functional A logical, if `TRUE`, registers the model to be fit with
@@ -85,7 +119,7 @@ create_keras_spec_impl <- function(
85119 functional ,
86120 env
87121) {
88- args_info <- collect_spec_args(layer_blocks )
122+ args_info <- collect_spec_args(layer_blocks , functional = function al )
89123 spec_fun <- build_spec_function(
90124 model_name ,
91125 mode ,
0 commit comments