@@ -106,6 +106,117 @@ collect_spec_args <- function(
106106 list (all_args = all_args , parsnip_names = parsnip_names )
107107}
108108
109+ # ' Remap Layer Block Arguments for Model Specification
110+ # '
111+ # ' @description
112+ # ' Creates a wrapper function around a Keras layer block to rename its
113+ # ' arguments. This is a powerful helper for defining the `layer_blocks` in
114+ # ' [create_keras_functional_spec()] and [create_keras_sequential_spec()],
115+ # ' allowing you to connect reusable blocks into a model graph without writing
116+ # ' verbose anonymous functions.
117+ # '
118+ # ' @details
119+ # ' `inp_spec()` makes your model definitions cleaner and more readable. It
120+ # ' handles the metaprogramming required to create a new function with the
121+ # ' correct argument names, while preserving the original block's hyperparameters
122+ # ' and their default values.
123+ # '
124+ # ' The function supports two modes of operation based on `input_map`:
125+ # ' 1. **Single Input Renaming**: If `input_map` is a single character string,
126+ # ' the wrapper function renames the *first* argument of the `block` function
127+ # ' to the provided string. This is the common case for blocks that take a
128+ # ' single tensor input.
129+ # ' 2. **Multiple Input Mapping**: If `input_map` is a named character vector,
130+ # ' it provides an explicit mapping from new argument names (the names of the
131+ # ' vector) to the original argument names in the `block` function (the values
132+ # ' of the vector). This is used for blocks with multiple inputs, like a
133+ # ' concatenation layer.
134+ # '
135+ # ' @param block A function that defines a Keras layer or a set of layers. The
136+ # ' first arguments should be the input tensor(s).
137+ # ' @param input_map A single character string or a named character vector that
138+ # ' specifies how to rename/remap the arguments of `block`.
139+ # '
140+ # ' @return A new function (a closure) that wraps the `block` function with
141+ # ' renamed arguments, ready to be used in a `layer_blocks` list.
142+ # '
143+ # ' @export
144+ # ' @examples
145+ # ' \dontrun{
146+ # ' # --- Example Blocks ---
147+ # ' # A standard dense block with one input tensor and one hyperparameter.
148+ # ' dense_block <- function(tensor, units = 16) {
149+ # ' tensor |> keras3::layer_dense(units = units, activation = "relu")
150+ # ' }
151+ # '
152+ # ' # A block that takes two tensors as input.
153+ # ' concat_block <- function(input_a, input_b) {
154+ # ' keras3::layer_concatenate(list(input_a, input_b))
155+ # ' }
156+ # '
157+ # ' # An output block with one input.
158+ # ' output_block <- function(tensor) {
159+ # ' tensor |> keras3::layer_dense(units = 1)
160+ # ' }
161+ # '
162+ # ' # --- Usage ---
163+ # ' layer_blocks <- list(
164+ # ' main_input = keras3::layer_input,
165+ # ' path_a = inp_spec(dense_block, "main_input"),
166+ # ' path_b = inp_spec(dense_block, "main_input"),
167+ # ' concatenated = inp_spec(
168+ # ' concat_block,
169+ # ' c(path_a = "input_a", path_b = "input_b")
170+ # ' ),
171+ # ' output = inp_spec(output_block, "concatenated")
172+ # ' )
173+ # ' }
174+ inp_spec <- function (block , input_map ) {
175+ new_fun <- function () {}
176+ original_formals <- formals(block )
177+ original_names <- names(original_formals )
178+
179+ if (length(original_formals ) == 0 ) {
180+ stop(" The 'block' function must have at least one argument." )
181+ }
182+
183+ new_formals <- original_formals
184+
185+ if (
186+ is.character(input_map ) &&
187+ is.null(names(input_map )) &&
188+ length(input_map ) == 1
189+ ) {
190+ # Case 1: Single string, rename first argument
191+ names(new_formals )[1 ] <- input_map
192+ } else if (is.character(input_map ) && ! is.null(names(input_map ))) {
193+ # Case 2: Named vector for mapping
194+ if (! all(input_map %in% original_names )) {
195+ missing_args <- input_map [! input_map %in% original_names ]
196+ stop(paste(
197+ " Argument(s)" ,
198+ paste(shQuote(missing_args ), collapse = " , " ),
199+ " not found in the block function."
200+ ))
201+ }
202+ for (new_name in names(input_map )) {
203+ old_name <- input_map [[new_name ]]
204+ names(new_formals )[original_names == old_name ] <- new_name
205+ }
206+ } else {
207+ stop(" `input_map` must be a single string or a named character vector." )
208+ }
209+
210+ formals(new_fun ) <- new_formals
211+
212+ call_args <- lapply(names(new_formals ), as.symbol )
213+ names(call_args ) <- original_names
214+
215+ body(new_fun ) <- as.call(c(list (as.symbol(" block" )), call_args ))
216+ environment(new_fun ) <- environment()
217+ new_fun
218+ }
219+
109220# ' Internal Implementation for Creating Keras Specifications
110221# '
111222# ' @description
0 commit comments