Skip to content

Commit 9a3a388

Browse files
committed
Adding support for functional API
1 parent 22065f4 commit 9a3a388

File tree

7 files changed

+633
-5
lines changed

7 files changed

+633
-5
lines changed

NAMESPACE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Generated by roxygen2: do not edit by hand
22

33
export(create_keras_functional_spec)
4-
export(create_keras_spec)
4+
export(create_keras_sequential_spec)
55
export(generic_functional_fit)
66
export(generic_sequential_fit)
77
export(keras_losses)

R/create_keras_functional_spec.R

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
#' Create a Custom Keras Functional API Model Specification for Tidymodels
2+
#'
3+
#' This function acts as a factory to generate a new `parsnip` model
4+
#' specification based on user-defined blocks of Keras layers using the
5+
#' Functional API. This allows for creating complex, tunable architectures
6+
#' with non-linear topologies that integrate seamlessly with the `tidymodels`
7+
#' ecosystem.
8+
#'
9+
#' @param model_name A character string for the name of the new model
10+
#' specification function (e.g., "custom_resnet"). This should be a valid R
11+
#' function name.
12+
#' @param layer_blocks A named list of functions where each function defines a
13+
#' "block" (a node) in the model graph. The list names are crucial as they
14+
#' define the names of the nodes. The arguments of each function define how
15+
#' the nodes are connected. See the "Model Graph Connectivity" section for
16+
#' details.
17+
#' @param mode A character string, either "regression" or "classification".
18+
#' @param ... Reserved for future use. Currently not used.
19+
#' @param env The environment in which to create the new model specification
20+
#' function and its associated `update()` method. Defaults to the calling
21+
#' environment (`parent.frame()`).
22+
#'
23+
#' @details
24+
#' This function generates all the boilerplate needed to create a custom,
25+
#' tunable `parsnip` model specification that uses the Keras Functional API.
26+
#' This is ideal for models with complex, non-linear topologies, such as
27+
#' networks with multiple inputs/outputs or residual connections.
28+
#'
29+
#' The function inspects the arguments of your `layer_blocks` functions and
30+
#' makes them available as tunable parameters in the generated model
31+
#' specification, prefixed with the block's name (e.g., the `units` argument of a block named `dense` becomes `dense_units`).
32+
#' Common training parameters such as `epochs` and `learn_rate` are also added.
33+
#'
34+
#' @section Model Graph Connectivity:
35+
#' `kerasnip` builds the model's directed acyclic graph by inspecting the
36+
#' arguments of each function in the `layer_blocks` list. The connection logic
37+
#' is as follows:
38+
#'
39+
#' 1. The **names of the elements** in the `layer_blocks` list define the names
40+
#' of the nodes in your graph (e.g., `main_input`, `dense_path`, `output`).
41+
#' 2. The **names of the arguments** in each block function specify its inputs.
42+
#' A block function like `my_block <- function(input_a, input_b, ...)`
43+
#' declares that it needs input from the nodes named `input_a` and `input_b`.
44+
#' `kerasnip` will automatically supply the output tensors from those nodes
45+
#' when calling `my_block`.
46+
#'
47+
#' There are two special requirements:
48+
#' * **Input Block**: The first block in the list is treated as the input
49+
#' node. Its function should not take other blocks as input, but it can have
50+
#' an `input_shape` argument, which is supplied automatically during fitting.
51+
#' * **Output Block**: Exactly one block must be named `"output"`. The tensor
52+
#' returned by this block is used as the final output of the Keras model.
53+
#'
54+
#' A key feature is the automatic creation of `num_{block_name}` arguments
55+
#' (e.g., `num_dense_path`). This allows you to control how many times a block
56+
#' is repeated, making it easy to tune the depth of your network. A block can
57+
#' only be repeated if it has exactly one input from another block in the graph.
58+
#'
59+
#' The new model specification function and its `update()` method are created
60+
#' in the environment specified by the `env` argument.
61+
#'
62+
#' @importFrom rlang enquos dots_list arg_match env_poke
63+
#' @importFrom parsnip update_dot_check
64+
#'
65+
#' @return Invisibly returns `NULL`. Its primary side effect is to create a
66+
#' new model specification function (e.g., `custom_resnet()`) in the
67+
#' specified environment and register the model with `parsnip` so it can be
68+
#' used within the `tidymodels` framework.
69+
#'
70+
#' @seealso [remove_keras_spec()], [parsnip::new_model_spec()],
71+
#' [create_keras_sequential_spec()]
72+
#'
73+
#' @export
74+
#' @examples
75+
#' \dontrun{
76+
#' if (requireNamespace("keras3", quietly = TRUE)) {
77+
#' library(keras3)
78+
#' library(parsnip)
79+
#'
80+
#' # 1. Define block functions. These are the building blocks of our model.
81+
#' # An input block that receives the data's shape automatically.
82+
#' input_block <- function(input_shape) layer_input(shape = input_shape)
83+
#'
84+
#' # A dense block with a tunable `units` parameter.
85+
#' dense_block <- function(tensor, units) {
86+
#' tensor |> layer_dense(units = units, activation = "relu")
87+
#' }
88+
#'
89+
#' # A block that adds two tensors together (for the residual connection).
90+
#' add_block <- function(input_a, input_b) layer_add(list(input_a, input_b))
91+
#'
92+
#' # An output block for regression.
93+
#' output_block_reg <- function(tensor) layer_dense(tensor, units = 1)
94+
#'
95+
#' # 2. Create the spec. The `layer_blocks` list defines the graph.
96+
#' create_keras_functional_spec(
97+
#' model_name = "my_resnet_spec",
98+
#' layer_blocks = list(
99+
#' # The names of list elements are the node names.
100+
#' main_input = input_block,
101+
#'
102+
#' # The argument `main_input` connects this block to the input node.
103+
#' dense_path = function(main_input, units = 32) dense_block(main_input, units),
104+
#'
105+
#' # This block's arguments connect it to the original input AND the dense layer.
106+
#' add_residual = function(main_input, dense_path) add_block(main_input, dense_path),
107+
#'
108+
#' # This block must be named 'output'. It connects to the residual add layer.
109+
#' output = function(add_residual) output_block_reg(add_residual)
110+
#' ),
111+
#' mode = "regression"
112+
#' )
113+
#'
114+
#' # 3. Use the newly created specification function!
115+
#' # The `dense_path_units` argument was created automatically.
116+
#' model_spec <- my_resnet_spec(dense_path_units = 64, epochs = 10)
117+
#'
118+
#' # You could also tune the number of dense layers since it has a single input:
119+
#' # model_spec <- my_resnet_spec(num_dense_path = 2, dense_path_units = 32)
120+
#'
121+
#' print(model_spec)
122+
#' # tune::tunable(model_spec)
123+
#' }
124+
#' }
125+
create_keras_functional_spec <- function(
  model_name,
  layer_blocks,
  mode = c("regression", "classification"),
  ...,
  env = parent.frame()
) {
  # Resolve `mode` against the choices declared in the signature; anything
  # else is rejected by `rlang::arg_match()` before any further work is done.
  resolved_mode <- rlang::arg_match(mode)

  # Delegate to `create_keras_spec_impl()` with the functional flag switched
  # on. `...` is intentionally not forwarded (documented as reserved).
  create_keras_spec_impl(
    model_name,
    layer_blocks,
    resolved_mode,
    functional = TRUE,
    env
  )
}

R/generate_roxygen_docs.R

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,30 @@ generate_roxygen_docs <- function(
8787
block_params <- setdiff(block_params, "learn_rate")
8888
}
8989
if (length(block_params) > 0) {
90+
# Sort block names by length descending to handle overlapping names
91+
# (e.g., "dense" and "dense_layer")
92+
sorted_block_names <- names(layer_blocks)[
93+
order(nchar(names(layer_blocks)), decreasing = TRUE)
94+
]
95+
9096
param_docs <- c(
9197
param_docs,
9298
purrr::map_chr(block_params, function(p) {
93-
parts <- strsplit(p, "_", fixed = TRUE)[[1]]
94-
block_name <- parts[1]
95-
param_name <- paste(parts[-1], collapse = "_")
99+
# Find the block name that is a prefix for this parameter.
100+
# The `Find` function returns the first match, and since we sorted
101+
# block names by length, it will find the longest possible match.
102+
block_name <- Find(
103+
function(bn) startsWith(p, paste0(bn, "_")),
104+
sorted_block_names
105+
)
106+
107+
if (is.null(block_name)) {
108+
# This should not happen if collect_spec_args is correct, but as a
109+
# fallback, we avoid an error.
110+
return(paste0("@param ", p, " A model parameter."))
111+
}
112+
113+
param_name <- sub(paste0(block_name, "_"), "", p, fixed = TRUE)
96114
block_fn <- layer_blocks[[block_name]]
97115
default_val <- rlang::fn_fmls(block_fn)[[param_name]]
98116
default_str <- if (

0 commit comments

Comments
 (0)