Skip to content

Commit 2cc6ad4

Browse files
committed
Added a helper function for input specification
1 parent 9a3a388 commit 2cc6ad4

File tree

1 file changed

+111
-0
lines changed

1 file changed

+111
-0
lines changed

R/create_keras_spec_helpers.R

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,117 @@ collect_spec_args <- function(
106106
list(all_args = all_args, parsnip_names = parsnip_names)
107107
}
108108

109+
#' Remap Layer Block Arguments for Model Specification
110+
#'
111+
#' @description
112+
#' Creates a wrapper function around a Keras layer block to rename its
113+
#' arguments. This is a powerful helper for defining the `layer_blocks` in
114+
#' [create_keras_functional_spec()] and [create_keras_sequential_spec()],
115+
#' allowing you to connect reusable blocks into a model graph without writing
116+
#' verbose anonymous functions.
117+
#'
118+
#' @details
119+
#' `inp_spec()` makes your model definitions cleaner and more readable. It
120+
#' handles the metaprogramming required to create a new function with the
121+
#' correct argument names, while preserving the original block's hyperparameters
122+
#' and their default values.
123+
#'
124+
#' The function supports two modes of operation based on `input_map`:
125+
#' 1. **Single Input Renaming**: If `input_map` is a single character string,
126+
#' the wrapper function renames the *first* argument of the `block` function
127+
#' to the provided string. This is the common case for blocks that take a
128+
#' single tensor input.
129+
#' 2. **Multiple Input Mapping**: If `input_map` is a named character vector,
130+
#' it provides an explicit mapping from new argument names (the names of the
131+
#' vector) to the original argument names in the `block` function (the values
132+
#' of the vector). This is used for blocks with multiple inputs, like a
133+
#' concatenation layer.
134+
#'
135+
#' @param block A function that defines a Keras layer or a set of layers. The
136+
#' first arguments should be the input tensor(s).
137+
#' @param input_map A single character string or a named character vector that
138+
#' specifies how to rename/remap the arguments of `block`.
139+
#'
140+
#' @return A new function (a closure) that wraps the `block` function with
141+
#' renamed arguments, ready to be used in a `layer_blocks` list.
142+
#'
143+
#' @export
144+
#' @examples
145+
#' \dontrun{
146+
#' # --- Example Blocks ---
147+
#' # A standard dense block with one input tensor and one hyperparameter.
148+
#' dense_block <- function(tensor, units = 16) {
149+
#' tensor |> keras3::layer_dense(units = units, activation = "relu")
150+
#' }
151+
#'
152+
#' # A block that takes two tensors as input.
153+
#' concat_block <- function(input_a, input_b) {
154+
#' keras3::layer_concatenate(list(input_a, input_b))
155+
#' }
156+
#'
157+
#' # An output block with one input.
158+
#' output_block <- function(tensor) {
159+
#' tensor |> keras3::layer_dense(units = 1)
160+
#' }
161+
#'
162+
#' # --- Usage ---
163+
#' layer_blocks <- list(
164+
#' main_input = keras3::layer_input,
165+
#' path_a = inp_spec(dense_block, "main_input"),
166+
#' path_b = inp_spec(dense_block, "main_input"),
167+
#' concatenated = inp_spec(
168+
#' concat_block,
169+
#' c(path_a = "input_a", path_b = "input_b")
170+
#' ),
171+
#' output = inp_spec(output_block, "concatenated")
172+
#' )
173+
#' }
174+
inp_spec <- function(block, input_map) {
175+
new_fun <- function() {}
176+
original_formals <- formals(block)
177+
original_names <- names(original_formals)
178+
179+
if (length(original_formals) == 0) {
180+
stop("The 'block' function must have at least one argument.")
181+
}
182+
183+
new_formals <- original_formals
184+
185+
if (
186+
is.character(input_map) &&
187+
is.null(names(input_map)) &&
188+
length(input_map) == 1
189+
) {
190+
# Case 1: Single string, rename first argument
191+
names(new_formals)[1] <- input_map
192+
} else if (is.character(input_map) && !is.null(names(input_map))) {
193+
# Case 2: Named vector for mapping
194+
if (!all(input_map %in% original_names)) {
195+
missing_args <- input_map[!input_map %in% original_names]
196+
stop(paste(
197+
"Argument(s)",
198+
paste(shQuote(missing_args), collapse = ", "),
199+
"not found in the block function."
200+
))
201+
}
202+
for (new_name in names(input_map)) {
203+
old_name <- input_map[[new_name]]
204+
names(new_formals)[original_names == old_name] <- new_name
205+
}
206+
} else {
207+
stop("`input_map` must be a single string or a named character vector.")
208+
}
209+
210+
formals(new_fun) <- new_formals
211+
212+
call_args <- lapply(names(new_formals), as.symbol)
213+
names(call_args) <- original_names
214+
215+
body(new_fun) <- as.call(c(list(as.symbol("block")), call_args))
216+
environment(new_fun) <- environment()
217+
new_fun
218+
}
219+
109220
#' Internal Implementation for Creating Keras Specifications
110221
#'
111222
#' @description

0 commit comments

Comments
 (0)