Skip to content

Commit 3182a25

Browse files
authored
Merge pull request #6 from davidrsch/functional-api-support
Addin support for functional appi closses #2es #
2 parents d474e3d + ade722e commit 3182a25

39 files changed

+2845
-1020
lines changed

NAMESPACE

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
# Generated by roxygen2: do not edit by hand
22

3-
export(create_keras_spec)
4-
export(generic_keras_fit_impl)
3+
export(create_keras_functional_spec)
4+
export(create_keras_sequential_spec)
5+
export(generic_functional_fit)
6+
export(generic_sequential_fit)
7+
export(inp_spec)
58
export(keras_losses)
69
export(keras_metrics)
710
export(keras_optimizers)

R/build_spec_function.R

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
#' Build the Model Specification Function
2+
#'
3+
#' @description
4+
#' This internal helper uses metaprogramming to construct a complete R function
5+
#' that acts as a `parsnip` model specification (e.g., `my_mlp()`).
6+
#'
7+
#' @details
8+
#' The process involves three main steps:
9+
#' 1. **Function Body Construction**: An expression for the function body is
10+
#' created. This body uses `rlang::enquo()` and `rlang::enquos()` to
11+
#' capture all user-provided arguments (both named and via `...`) into a
12+
#' list of quosures. This list is then passed to `parsnip::new_model_spec()`.
13+
#' 2. **Function Signature Construction**: A formal argument list is created
14+
#' from `all_args`, and `...` is added to allow passthrough arguments.
15+
#' `rlang::new_function()` combines the signature and body into a new
16+
#' function object.
17+
#' 3. **Documentation Attachment**: `generate_roxygen_docs()` creates a
18+
#' comprehensive Roxygen comment block as a string, which is then attached
19+
#' to the new function using `comment()`.
20+
#'
21+
#' @param model_name The name of the model specification function to create (e.g., "my_mlp").
22+
#' @param mode The model mode ("regression" or "classification").
23+
#' @param all_args A named list of formal arguments for the new function's
24+
#' signature, as generated by `collect_spec_args()`. The values are typically
25+
#' `rlang::missing_arg()` or `rlang::zap()`.
26+
#' @param parsnip_names A character vector of all argument names that should be
27+
#' captured as quosures and passed to `parsnip::new_model_spec()`.
28+
#' @param layer_blocks The user-provided list of layer block functions. This is
29+
#' passed directly to `generate_roxygen_docs()` to create documentation for
30+
#' block-specific parameters.
31+
#' @param functional A logical indicating if the model is functional
32+
#' (for `create_keras_functional_spec()`) or sequential. This is passed to
33+
#' `generate_roxygen_docs()` to tailor the documentation.
34+
#' @return A new function object with attached Roxygen comments, ready to be
35+
#' placed in the user's environment.
36+
#' @noRd
37+
build_spec_function <- function(
38+
model_name,
39+
mode,
40+
all_args,
41+
parsnip_names,
42+
layer_blocks,
43+
functional = FALSE
44+
) {
45+
quos_exprs <- purrr::map(
46+
parsnip_names,
47+
~ rlang::expr(rlang::enquo(!!rlang::sym(.x)))
48+
)
49+
names(quos_exprs) <- parsnip_names
50+
51+
body <- rlang::expr({
52+
# Capture both explicit args and ... to pass to the fit impl
53+
# Named arguments are captured into a list of quosures.
54+
main_args <- rlang::list2(!!!quos_exprs)
55+
# ... arguments are captured into a separate list of quosures.
56+
dot_args <- rlang::enquos(...)
57+
args <- c(main_args, dot_args)
58+
parsnip::new_model_spec(
59+
!!model_name,
60+
args = args,
61+
eng_args = NULL,
62+
mode = !!mode,
63+
method = NULL,
64+
engine = NULL
65+
)
66+
})
67+
68+
# Add ... to the function signature to capture any other compile arguments
69+
fn_args <- c(all_args, list(... = rlang::missing_arg()))
70+
71+
fn <- rlang::new_function(args = fn_args, body = body)
72+
73+
docs <- generate_roxygen_docs(
74+
model_name,
75+
layer_blocks,
76+
all_args,
77+
functional = functional
78+
)
79+
comment(fn) <- docs
80+
fn
81+
}

R/create_keras_functional_spec.R

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
#' Create a Custom Keras Functional API Model Specification for Tidymodels
2+
#'
3+
#' This function acts as a factory to generate a new `parsnip` model
4+
#' specification based on user-defined blocks of Keras layers using the
5+
#' Functional API. This allows for creating complex, tunable architectures
6+
#' with non-linear topologies that integrate seamlessly with the `tidymodels`
7+
#' ecosystem.
8+
#'
9+
#' @param model_name A character string for the name of the new model
10+
#' specification function (e.g., "custom_resnet"). This should be a valid R
11+
#' function name.
12+
#' @param layer_blocks A named list of functions where each function defines a
13+
#' "block" (a node) in the model graph. The list names are crucial as they
14+
#' define the names of the nodes. The arguments of each function define how
15+
#' the nodes are connected. See the "Model Graph Connectivity" section for
16+
#' details.
17+
#' @param mode A character string, either "regression" or "classification".
18+
#' @param ... Reserved for future use. Currently not used.
19+
#' @param env The environment in which to create the new model specification
20+
#' function and its associated `update()` method. Defaults to the calling
21+
#' environment (`parent.frame()`).
22+
#'
23+
#' @details
24+
#' This function generates all the boilerplate needed to create a custom,
25+
#' tunable `parsnip` model specification that uses the Keras Functional API.
26+
#' This is ideal for models with complex, non-linear topologies, such as
27+
#' networks with multiple inputs/outputs or residual connections.
28+
#'
29+
#' The function inspects the arguments of your `layer_blocks` functions and
30+
#' makes them available as tunable parameters in the generated model
31+
#' specification, prefixed with the block's name (e.g., `dense_units`).
32+
#' Common training parameters such as `epochs` and `learn_rate` are also added.
33+
#'
34+
#' @section Model Graph Connectivity:
35+
#' `kerasnip` builds the model's directed acyclic graph by inspecting the
36+
#' arguments of each function in the `layer_blocks` list. The connection logic
37+
#' is as follows:
38+
#'
39+
#' 1. The **names of the elements** in the `layer_blocks` list define the names
40+
#' of the nodes in your graph (e.g., `main_input`, `dense_path`, `output`).
41+
#' 2. The **names of the arguments** in each block function specify its inputs.
42+
#' A block function like `my_block <- function(input_a, input_b, ...)`
43+
#' declares that it needs input from the nodes named `input_a` and `input_b`.
44+
#' `kerasnip` will automatically supply the output tensors from those nodes
45+
#' when calling `my_block`.
46+
#'
47+
#' There are two special requirements:
48+
#' * **Input Block**: The first block in the list is treated as the input
49+
#' node. Its function should not take other blocks as input, but it can have
50+
#' an `input_shape` argument, which is supplied automatically during fitting.
51+
#' * **Output Block**: Exactly one block must be named `"output"`. The tensor
52+
#' returned by this block is used as the final output of the Keras model.
53+
#'
54+
#' A key feature is the automatic creation of `num_{block_name}` arguments
55+
#' (e.g., `num_dense_path`). This allows you to control how many times a block
56+
#' is repeated, making it easy to tune the depth of your network. A block can
57+
#' only be repeated if it has exactly one input from another block in the graph.
58+
#'
59+
#' The new model specification function and its `update()` method are created
60+
#' in the environment specified by the `env` argument.
61+
#'
62+
#' @importFrom rlang enquos dots_list arg_match env_poke
63+
#' @importFrom parsnip update_dot_check
64+
#'
65+
#' @return Invisibly returns `NULL`. Its primary side effect is to create a
66+
#' new model specification function (e.g., `custom_resnet()`) in the
67+
#' specified environment and register the model with `parsnip` so it can be
68+
#' used within the `tidymodels` framework.
69+
#'
70+
#' @seealso [remove_keras_spec()], [parsnip::new_model_spec()],
71+
#' [create_keras_sequential_spec()]
72+
#'
73+
#' @export
74+
#' @examples
75+
#' \dontrun{
76+
#' if (requireNamespace("keras3", quietly = TRUE)) {
77+
#' library(keras3)
78+
#' library(parsnip)
79+
#'
80+
#' # 1. Define block functions. These are the building blocks of our model.
81+
#' # An input block that receives the data's shape automatically.
82+
#' input_block <- function(input_shape) layer_input(shape = input_shape)
83+
#'
84+
#' # A dense block with a tunable `units` parameter.
85+
#' dense_block <- function(tensor, units) {
86+
#' tensor |> layer_dense(units = units, activation = "relu")
87+
#' }
88+
#'
89+
#' # A block that adds two tensors together (for the residual connection).
90+
#' add_block <- function(input_a, input_b) layer_add(list(input_a, input_b))
91+
#'
92+
#' # An output block for regression.
93+
#' output_block_reg <- function(tensor) layer_dense(tensor, units = 1)
94+
#'
95+
#' # 2. Create the spec. The `layer_blocks` list defines the graph.
96+
#' create_keras_functional_spec(
97+
#' model_name = "my_resnet_spec",
98+
#' layer_blocks = list(
99+
#' # The names of list elements are the node names.
100+
#' main_input = input_block,
101+
#'
102+
#' # The argument `main_input` connects this block to the input node.
103+
#' dense_path = function(main_input, units = 32) dense_block(main_input, units),
104+
#'
105+
#' # This block's arguments connect it to the original input AND the dense layer.
106+
#' add_residual = function(main_input, dense_path) add_block(main_input, dense_path),
107+
#'
108+
#' # This block must be named 'output'. It connects to the residual add layer.
109+
#' output = function(add_residual) output_block_reg(add_residual)
110+
#' ),
111+
#' mode = "regression"
112+
#' )
113+
#'
114+
#' # 3. Use the newly created specification function!
115+
#' # The `dense_path_units` argument was created automatically.
116+
#' model_spec <- my_resnet_spec(dense_path_units = 64, epochs = 10)
117+
#'
118+
#' # You could also tune the number of dense layers since it has a single input:
119+
#' # model_spec <- my_resnet_spec(num_dense_path = 2, dense_path_units = 32)
120+
#'
121+
#' print(model_spec)
122+
#' # tune::tunable(model_spec)
123+
#' }
124+
#' }
125+
create_keras_functional_spec <- function(
126+
model_name,
127+
layer_blocks,
128+
mode = c("regression", "classification"),
129+
...,
130+
env = parent.frame()
131+
) {
132+
mode <- rlang::arg_match(mode)
133+
# 1. Argument Validation
134+
create_keras_spec_impl(
135+
model_name,
136+
layer_blocks,
137+
mode,
138+
functional = TRUE,
139+
env
140+
)
141+
}
Lines changed: 46 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,63 @@
1-
#' Create a Custom Keras Model Specification for Tidymodels
1+
#' Create a Custom Keras Sequential Model Specification for Tidymodels
22
#'
3+
#' @description
34
#' This function acts as a factory to generate a new `parsnip` model
4-
#' specification based on user-defined blocks of Keras layers. This allows for
5-
#' creating complex, tunable architectures that integrate seamlessly with the
6-
#' `tidymodels` ecosystem.
5+
#' specification based on user-defined blocks of Keras layers using the
6+
#' Sequential API. This is the ideal choice for creating models that are a
7+
#' simple, linear stack of layers. For models with complex, non-linear
8+
#' topologies, see [create_keras_functional_spec()].
79
#'
810
#' @param model_name A character string for the name of the new model
911
#' specification function (e.g., "custom_cnn"). This should be a valid R
1012
#' function name.
11-
#' @param layer_blocks A named list of functions. Each function defines a "block"
12-
#' of Keras layers. The function must take a Keras model object as its first
13-
#' argument and return the modified model. Other arguments to the function
14-
#' will become tunable parameters in the final model specification.
13+
#' @param layer_blocks A named, ordered list of functions. Each function defines
14+
#' a "block" of Keras layers. The function must take a Keras model object as
15+
#' its first argument and return the modified model. Other arguments to the
16+
#' function will become tunable parameters in the final model specification.
1517
#' @param mode A character string, either "regression" or "classification".
1618
#' @param ... Reserved for future use. Currently not used.
1719
#' @param env The environment in which to create the new model specification
1820
#' function and its associated `update()` method. Defaults to the calling
1921
#' environment (`parent.frame()`).
20-
#' @importFrom rlang enquos dots_list arg_match env_poke
21-
#' @importFrom parsnip update_dot_check
2222
#'
2323
#' @details
24-
#' The user is responsible for defining the entire model architecture by providing
25-
#' an ordered list of layer block functions.
26-
#' 1. The first block function must initialize the model (e.g., with
27-
#' \code{keras_model_sequential()}). It can accept an \code{input_shape} argument,
28-
#' which will be provided automatically by the fitting engine.
29-
#' 2. Subsequent blocks add hidden layers.
30-
#' 3. The final block should add the output layer. For classification, it can
31-
#' accept a \code{num_classes} argument, which is provided automatically.
32-
#'
33-
#' The \code{create_keras_spec()} function will inspect the arguments of your
34-
#' \code{layer_blocks} functions (ignoring \code{input_shape} and \code{num_classes})
35-
#' and make them available as arguments in the generated model specification,
36-
#' prefixed with the block's name (e.g.,
37-
#' `dense_units`).
38-
#'
39-
#' It also automatically creates arguments like `num_dense` to control how many
40-
#' times each block is repeated. In addition, common training parameters such as
41-
#' `epochs`, `learn_rate`, `validation_split`, and `verbose` are added to the
42-
#' specification.
24+
#' This function generates all the boilerplate needed to create a custom,
25+
#' tunable `parsnip` model specification that uses the Keras Sequential API.
26+
#'
27+
#' The function inspects the arguments of your `layer_blocks` functions
28+
#' (ignoring special arguments like `input_shape` and `num_classes`)
29+
#' and makes them available as arguments in the generated model specification,
30+
#' prefixed with the block's name (e.g., `dense_units`).
4331
#'
4432
#' The new model specification function and its `update()` method are created in
4533
#' the environment specified by the `env` argument.
4634
#'
35+
#' @section Model Architecture (Sequential API):
36+
#' `kerasnip` builds the model by applying the functions in `layer_blocks` in
37+
#' the order they are provided. Each function receives the Keras model built by
38+
#' the previous function and returns a modified version.
39+
#'
40+
#' 1. The **first block** must initialize the model (e.g., with
41+
#' `keras_model_sequential()`). It can accept an `input_shape` argument,
42+
#' which `kerasnip` will provide automatically during fitting.
43+
#' 2. **Subsequent blocks** add layers to the model.
44+
#' 3. The **final block** should add the output layer. For classification, it
45+
#' can accept a `num_classes` argument, which is provided automatically.
46+
#'
47+
#' A key feature of this function is the automatic creation of `num_{block_name}`
48+
#' arguments (e.g., `num_hidden`). This allows you to control how many times
49+
#' each block is repeated, making it easy to tune the depth of your network.
50+
#'
51+
#' @importFrom rlang enquos dots_list arg_match env_poke
52+
#' @importFrom parsnip update_dot_check
53+
#'
4754
#' @return Invisibly returns `NULL`. Its primary side effect is to create a new
48-
#' model specification function (e.g., `dynamic_mlp()`) in the specified
55+
#' model specification function (e.g., `my_mlp()`) in the specified
4956
#' environment and register the model with `parsnip` so it can be used within
5057
#' the `tidymodels` framework.
5158
#'
52-
#' @seealso [remove_keras_spec()], [parsnip::new_model_spec()]
59+
#' @seealso [remove_keras_spec()], [parsnip::new_model_spec()],
60+
#' [create_keras_functional_spec()]
5361
#'
5462
#' @export
5563
#' @examples
@@ -75,7 +83,7 @@
7583
#' }
7684
#'
7785
#' # 2. Create the spec, providing blocks in the correct order.
78-
#' create_keras_spec(
86+
#' create_keras_sequential_spec(
7987
#' model_name = "my_mlp",
8088
#' layer_blocks = list(
8189
#' input = input_block,
@@ -86,7 +94,7 @@
8694
#' )
8795
#'
8896
#' # 3. Use the newly created specification function!
89-
# Note the new arguments `num_hidden` and `hidden_units`.
97+
#' # Note the new arguments `num_hidden` and `hidden_units`.
9098
#' model_spec <- my_mlp(
9199
#' num_hidden = 2,
92100
#' hidden_units = 64,
@@ -97,28 +105,19 @@
97105
#' print(model_spec)
98106
#' }
99107
#' }
100-
create_keras_spec <- function(
108+
create_keras_sequential_spec <- function(
101109
model_name,
102110
layer_blocks,
103111
mode = c("regression", "classification"),
104112
...,
105113
env = parent.frame()
106114
) {
107115
mode <- arg_match(mode)
108-
args_info <- collect_spec_args(layer_blocks)
109-
spec_fun <- build_spec_function(
116+
create_keras_spec_impl(
110117
model_name,
118+
layer_blocks,
111119
mode,
112-
args_info$all_args,
113-
args_info$parsnip_names,
114-
layer_blocks
120+
functional = FALSE,
121+
env
115122
)
116-
117-
register_core_model(model_name, mode)
118-
register_model_args(model_name, args_info$parsnip_names)
119-
register_fit_predict(model_name, mode, layer_blocks)
120-
register_update_method(model_name, args_info$parsnip_names, env = env)
121-
122-
env_poke(env, model_name, spec_fun)
123-
invisible(NULL)
124123
}

0 commit comments

Comments
 (0)