|
1 | | -#' Create a Custom Keras Model Specification for Tidymodels |
| 1 | +#' Create a Custom Keras Sequential Model Specification for Tidymodels |
2 | 2 | #' |
| 3 | +#' @description |
3 | 4 | #' This function acts as a factory to generate a new `parsnip` model |
4 | | -#' specification based on user-defined blocks of Keras layers. This allows for |
5 | | -#' creating complex, tunable architectures that integrate seamlessly with the |
6 | | -#' `tidymodels` ecosystem. |
| 5 | +#' specification based on user-defined blocks of Keras layers using the |
| 6 | +#' Sequential API. This is the ideal choice for creating models that are a |
| 7 | +#' simple, linear stack of layers. For models with complex, non-linear |
| 8 | +#' topologies, see [create_keras_functional_spec()]. |
7 | 9 | #' |
8 | 10 | #' @param model_name A character string for the name of the new model |
9 | 11 | #' specification function (e.g., "custom_cnn"). This should be a valid R |
10 | 12 | #' function name. |
11 | | -#' @param layer_blocks A named list of functions. Each function defines a "block" |
12 | | -#' of Keras layers. The function must take a Keras model object as its first |
13 | | -#' argument and return the modified model. Other arguments to the function |
14 | | -#' will become tunable parameters in the final model specification. |
| 13 | +#' @param layer_blocks A named, ordered list of functions. Each function defines |
| 14 | +#' a "block" of Keras layers. The function must take a Keras model object as |
| 15 | +#' its first argument and return the modified model. Other arguments to the |
| 16 | +#' function will become tunable parameters in the final model specification. |
15 | 17 | #' @param mode A character string, either "regression" or "classification". |
16 | 18 | #' @param ... Reserved for future use. Currently not used. |
17 | 19 | #' @param env The environment in which to create the new model specification |
18 | 20 | #' function and its associated `update()` method. Defaults to the calling |
19 | 21 | #' environment (`parent.frame()`). |
20 | | -#' @importFrom rlang enquos dots_list arg_match env_poke |
21 | | -#' @importFrom parsnip update_dot_check |
22 | 22 | #' |
23 | 23 | #' @details |
24 | | -#' The user is responsible for defining the entire model architecture by providing |
25 | | -#' an ordered list of layer block functions. |
26 | | -#' 1. The first block function must initialize the model (e.g., with |
27 | | -#' \code{keras_model_sequential()}). It can accept an \code{input_shape} argument, |
28 | | -#' which will be provided automatically by the fitting engine. |
29 | | -#' 2. Subsequent blocks add hidden layers. |
30 | | -#' 3. The final block should add the output layer. For classification, it can |
31 | | -#' accept a \code{num_classes} argument, which is provided automatically. |
32 | | -#' |
33 | | -#' The \code{create_keras_spec()} function will inspect the arguments of your |
34 | | -#' \code{layer_blocks} functions (ignoring \code{input_shape} and \code{num_classes}) |
35 | | -#' and make them available as arguments in the generated model specification, |
36 | | -#' prefixed with the block's name (e.g., |
37 | | -#' `dense_units`). |
38 | | -#' |
39 | | -#' It also automatically creates arguments like `num_dense` to control how many |
40 | | -#' times each block is repeated. In addition, common training parameters such as |
41 | | -#' `epochs`, `learn_rate`, `validation_split`, and `verbose` are added to the |
42 | | -#' specification. |
| 24 | +#' This function generates all the boilerplate needed to create a custom, |
| 25 | +#' tunable `parsnip` model specification that uses the Keras Sequential API. |
| 26 | +#' |
| 27 | +#' The function inspects the arguments of your `layer_blocks` functions |
| 28 | +#' (ignoring special arguments like `input_shape` and `num_classes`) |
| 29 | +#' and makes them available as arguments in the generated model specification, |
| 30 | +#' prefixed with the block's name (e.g., `dense_units`). |
43 | 31 | #' |
44 | 32 | #' The new model specification function and its `update()` method are created in |
45 | 33 | #' the environment specified by the `env` argument. |
46 | 34 | #' |
| 35 | +#' @section Model Architecture (Sequential API): |
| 36 | +#' `kerasnip` builds the model by applying the functions in `layer_blocks` in |
| 37 | +#' the order they are provided. Each function receives the Keras model built by |
| 38 | +#' the previous function and returns a modified version. |
| 39 | +#' |
| 40 | +#' 1. The **first block** must initialize the model (e.g., with |
| 41 | +#' `keras_model_sequential()`). It can accept an `input_shape` argument, |
| 42 | +#' which `kerasnip` will provide automatically during fitting. |
| 43 | +#' 2. **Subsequent blocks** add layers to the model. |
| 44 | +#' 3. The **final block** should add the output layer. For classification, it |
| 45 | +#' can accept a `num_classes` argument, which is provided automatically. |
| 46 | +#' |
| 47 | +#' A key feature of this function is the automatic creation of `num_{block_name}` |
| 48 | +#' arguments (e.g., `num_hidden`). This allows you to control how many times |
| 49 | +#' each block is repeated, making it easy to tune the depth of your network. |
| 50 | +#' |
| 51 | +#' @importFrom rlang enquos dots_list arg_match env_poke |
| 52 | +#' @importFrom parsnip update_dot_check |
| 53 | +#' |
47 | 54 | #' @return Invisibly returns `NULL`. Its primary side effect is to create a new |
48 | | -#' model specification function (e.g., `dynamic_mlp()`) in the specified |
| 55 | +#' model specification function (e.g., `my_mlp()`) in the specified |
49 | 56 | #' environment and register the model with `parsnip` so it can be used within |
50 | 57 | #' the `tidymodels` framework. |
51 | 58 | #' |
52 | | -#' @seealso [remove_keras_spec()], [parsnip::new_model_spec()] |
| 59 | +#' @seealso [remove_keras_spec()], [parsnip::new_model_spec()], |
| 60 | +#' [create_keras_functional_spec()] |
53 | 61 | #' |
54 | 62 | #' @export |
55 | 63 | #' @examples |
|
75 | 83 | #' } |
76 | 84 | #' |
77 | 85 | #' # 2. Create the spec, providing blocks in the correct order. |
78 | | -#' create_keras_spec( |
| 86 | +#' create_keras_sequential_spec( |
79 | 87 | #' model_name = "my_mlp", |
80 | 88 | #' layer_blocks = list( |
81 | 89 | #' input = input_block, |
|
86 | 94 | #' ) |
87 | 95 | #' |
88 | 96 | #' # 3. Use the newly created specification function! |
89 | | -# Note the new arguments `num_hidden` and `hidden_units`. |
| 97 | +#' # Note the new arguments `num_hidden` and `hidden_units`. |
90 | 98 | #' model_spec <- my_mlp( |
91 | 99 | #' num_hidden = 2, |
92 | 100 | #' hidden_units = 64, |
|
97 | 105 | #' print(model_spec) |
98 | 106 | #' } |
99 | 107 | #' } |
100 | | -create_keras_spec <- function( |
| 108 | +create_keras_sequential_spec <- function( |
101 | 109 | model_name, |
102 | 110 | layer_blocks, |
103 | 111 | mode = c("regression", "classification"), |
104 | 112 | ..., |
105 | 113 | env = parent.frame() |
106 | 114 | ) { |
107 | 115 | mode <- arg_match(mode) |
108 | | - args_info <- collect_spec_args(layer_blocks) |
109 | | - spec_fun <- build_spec_function( |
| 116 | + create_keras_spec_impl( |
110 | 117 | model_name, |
| 118 | + layer_blocks, |
111 | 119 | mode, |
112 | | - args_info$all_args, |
113 | | - args_info$parsnip_names, |
114 | | - layer_blocks |
| 120 | + functional = FALSE, |
| 121 | + env |
115 | 122 | ) |
116 | | - |
117 | | - register_core_model(model_name, mode) |
118 | | - register_model_args(model_name, args_info$parsnip_names) |
119 | | - register_fit_predict(model_name, mode, layer_blocks) |
120 | | - register_update_method(model_name, args_info$parsnip_names, env = env) |
121 | | - |
122 | | - env_poke(env, model_name, spec_fun) |
123 | | - invisible(NULL) |
124 | 123 | } |
0 commit comments