Skip to content

Commit 81b4193

Browse files
committed
Adding core model specification generator
1 parent ca8805f commit 81b4193

File tree

8 files changed

+1195
-15
lines changed

8 files changed

+1195
-15
lines changed

R/create_keras_spec.R

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
#' Create a Custom Keras Model Specification for Tidymodels
2+
#'
3+
#' This function acts as a factory to generate a new `parsnip` model
4+
#' specification based on user-defined blocks of Keras layers. This allows for
5+
#' creating complex, tunable architectures that integrate seamlessly with the
6+
#' `tidymodels` ecosystem.
7+
#'
8+
#' @param model_name A character string for the name of the new model
9+
#' specification function (e.g., "custom_cnn"). This should be a valid R
10+
#' function name.
11+
#' @param layer_blocks A named list of functions. Each function defines a "block"
12+
#' of Keras layers. The function must take a Keras model object as its first
13+
#' argument and return the modified model. Other arguments to the function
14+
#' will become tunable parameters in the final model specification.
15+
#' @param mode A character string, either "regression" or "classification".
16+
#' @param ... Reserved for future use. Currently not used.
17+
#' @param env The environment in which to create the new model specification
18+
#' function and its associated `update()` method. Defaults to the calling
19+
#' environment (`parent.frame()`).
20+
#' @importFrom rlang enquos dots_list arg_match env_poke
21+
#' @importFrom parsnip update_dot_check
22+
#'
23+
#' @details
24+
#' The user is responsible for defining the entire model architecture by providing
25+
#' an ordered list of layer block functions.
26+
#' 1. The first block function must initialize the model (e.g., with
27+
#' \code{keras_model_sequential()}). It can accept an \code{input_shape} argument,
28+
#' which will be provided automatically by the fitting engine.
29+
#' 2. Subsequent blocks add hidden layers.
30+
#' 3. The final block should add the output layer. For classification, it can
31+
#' accept a \code{num_classes} argument, which is provided automatically.
32+
#'
33+
#' The \code{create_keras_spec()} function will inspect the arguments of your
34+
#' \code{layer_blocks} functions (ignoring \code{input_shape} and \code{num_classes})
35+
#' and make them available as arguments in the generated model specification,
36+
#' prefixed with the block's name (e.g.,
37+
#' `dense_units`).
38+
#'
39+
#' It also automatically creates arguments like `num_dense` to control how many
40+
#' times each block is repeated. In addition, common training parameters such as
41+
#' `epochs`, `learn_rate`, `validation_split`, and `verbose` are added to the
42+
#' specification.
43+
#'
44+
#' The new model specification function and its `update()` method are created in
45+
#' the environment specified by the `env` argument.
46+
#'
47+
#' @return Invisibly returns `NULL`. Its primary side effect is to create a new
48+
#' model specification function (e.g., `dynamic_mlp()`) in the specified
49+
#' environment and register the model with `parsnip` so it can be used within
50+
#' the `tidymodels` framework.
51+
#'
52+
#' @seealso [remove_keras_spec()], [parsnip::new_model_spec()]
53+
#'
54+
#' @export
55+
#' @examples
56+
#' \dontrun{
57+
#' if (requireNamespace("keras3", quietly = TRUE)) {
58+
#' library(keras3)
59+
#' library(parsnip)
60+
#' library(dials)
61+
#'
62+
#' # 1. Define layer blocks for a complete model.
63+
#' # The first block must initialize the model. `input_shape` is passed automatically.
64+
#' input_block <- function(model, input_shape) {
65+
#' keras_model_sequential(input_shape = input_shape)
66+
#' }
67+
#' # A block for hidden layers. `units` will become a tunable parameter.
68+
#' hidden_block <- function(model, units = 32) {
69+
#' model |> layer_dense(units = units, activation = "relu")
70+
#' }
71+
#'
72+
#' # The output block. `num_classes` is passed automatically for classification.
73+
#' output_block <- function(model, num_classes) {
74+
#' model |> layer_dense(units = num_classes, activation = "softmax")
75+
#' }
76+
#'
77+
#' # 2. Create the spec, providing blocks in the correct order.
78+
#' create_keras_spec(
79+
#' model_name = "my_mlp",
80+
#' layer_blocks = list(
81+
#' input = input_block,
82+
#' hidden = hidden_block,
83+
#' output = output_block
84+
#' ),
85+
#' mode = "classification"
86+
#' )
87+
#'
88+
#' # 3. Use the newly created specification function!
89+
# Note the new arguments `num_hidden` and `hidden_units`.
90+
#' model_spec <- my_mlp(
91+
#' num_hidden = 2,
92+
#' hidden_units = 64,
93+
#' epochs = 10,
94+
#' learn_rate = 0.01
95+
#' )
96+
#'
97+
#' print(model_spec)
98+
#' }
99+
#' }
100+
create_keras_spec <- function(
101+
model_name,
102+
layer_blocks,
103+
mode = c("regression", "classification"),
104+
...,
105+
env = parent.frame()
106+
) {
107+
mode <- arg_match(mode)
108+
args_info <- collect_spec_args(layer_blocks)
109+
spec_fun <- build_spec_function(
110+
model_name,
111+
mode,
112+
args_info$all_args,
113+
args_info$parsnip_names,
114+
layer_blocks
115+
)
116+
117+
register_core_model(model_name, mode)
118+
register_model_args(model_name, args_info$parsnip_names)
119+
register_fit_predict(model_name, mode, layer_blocks)
120+
register_update_method(model_name, args_info$parsnip_names)
121+
122+
env_poke(env, model_name, spec_fun)
123+
invisible(NULL)
124+
}

0 commit comments

Comments
 (0)