|
| 1 | +#' Create a Custom Keras Model Specification for Tidymodels |
| 2 | +#' |
| 3 | +#' This function acts as a factory to generate a new `parsnip` model |
| 4 | +#' specification based on user-defined blocks of Keras layers. This allows for |
| 5 | +#' creating complex, tunable architectures that integrate seamlessly with the |
| 6 | +#' `tidymodels` ecosystem. |
| 7 | +#' |
| 8 | +#' @param model_name A character string for the name of the new model |
| 9 | +#' specification function (e.g., "custom_cnn"). This should be a valid R |
| 10 | +#' function name. |
| 11 | +#' @param layer_blocks A named list of functions. Each function defines a "block" |
| 12 | +#' of Keras layers. The function must take a Keras model object as its first |
| 13 | +#' argument and return the modified model. Other arguments to the function |
| 14 | +#' will become tunable parameters in the final model specification. |
| 15 | +#' @param mode A character string, either "regression" or "classification". |
| 16 | +#' @param ... Reserved for future use. Currently not used. |
| 17 | +#' @param env The environment in which to create the new model specification |
| 18 | +#' function and its associated `update()` method. Defaults to the calling |
| 19 | +#' environment (`parent.frame()`). |
| 20 | +#' @importFrom rlang enquos dots_list arg_match env_poke |
| 21 | +#' @importFrom parsnip update_dot_check |
| 22 | +#' |
| 23 | +#' @details |
| 24 | +#' The user is responsible for defining the entire model architecture by providing |
| 25 | +#' an ordered list of layer block functions. |
| 26 | +#' 1. The first block function must initialize the model (e.g., with |
| 27 | +#' \code{keras_model_sequential()}). It can accept an \code{input_shape} argument, |
| 28 | +#' which will be provided automatically by the fitting engine. |
| 29 | +#' 2. Subsequent blocks add hidden layers. |
| 30 | +#' 3. The final block should add the output layer. For classification, it can |
| 31 | +#' accept a \code{num_classes} argument, which is provided automatically. |
| 32 | +#' |
| 33 | +#' The \code{create_keras_spec()} function will inspect the arguments of your |
| 34 | +#' \code{layer_blocks} functions (ignoring \code{input_shape} and \code{num_classes}) |
| 35 | +#' and make them available as arguments in the generated model specification, |
| 36 | +#' prefixed with the block's name (e.g., |
| 37 | +#' `dense_units`). |
| 38 | +#' |
| 39 | +#' It also automatically creates arguments like `num_dense` to control how many |
| 40 | +#' times each block is repeated. In addition, common training parameters such as |
| 41 | +#' `epochs`, `learn_rate`, `validation_split`, and `verbose` are added to the |
| 42 | +#' specification. |
| 43 | +#' |
| 44 | +#' The new model specification function and its `update()` method are created in |
| 45 | +#' the environment specified by the `env` argument. |
| 46 | +#' |
| 47 | +#' @return Invisibly returns `NULL`. Its primary side effect is to create a new |
| 48 | +#' model specification function (e.g., `dynamic_mlp()`) in the specified |
| 49 | +#' environment and register the model with `parsnip` so it can be used within |
| 50 | +#' the `tidymodels` framework. |
| 51 | +#' |
| 52 | +#' @seealso [remove_keras_spec()], [parsnip::new_model_spec()] |
| 53 | +#' |
| 54 | +#' @export |
| 55 | +#' @examples |
| 56 | +#' \dontrun{ |
| 57 | +#' if (requireNamespace("keras3", quietly = TRUE)) { |
| 58 | +#' library(keras3) |
| 59 | +#' library(parsnip) |
| 60 | +#' library(dials) |
| 61 | +#' |
| 62 | +#' # 1. Define layer blocks for a complete model. |
| 63 | +#' # The first block must initialize the model. `input_shape` is passed automatically. |
| 64 | +#' input_block <- function(model, input_shape) { |
| 65 | +#' keras_model_sequential(input_shape = input_shape) |
| 66 | +#' } |
| 67 | +#' # A block for hidden layers. `units` will become a tunable parameter. |
| 68 | +#' hidden_block <- function(model, units = 32) { |
| 69 | +#' model |> layer_dense(units = units, activation = "relu") |
| 70 | +#' } |
| 71 | +#' |
| 72 | +#' # The output block. `num_classes` is passed automatically for classification. |
| 73 | +#' output_block <- function(model, num_classes) { |
| 74 | +#' model |> layer_dense(units = num_classes, activation = "softmax") |
| 75 | +#' } |
| 76 | +#' |
| 77 | +#' # 2. Create the spec, providing blocks in the correct order. |
| 78 | +#' create_keras_spec( |
| 79 | +#' model_name = "my_mlp", |
| 80 | +#' layer_blocks = list( |
| 81 | +#' input = input_block, |
| 82 | +#' hidden = hidden_block, |
| 83 | +#' output = output_block |
| 84 | +#' ), |
| 85 | +#' mode = "classification" |
| 86 | +#' ) |
| 87 | +#' |
| 88 | +#' # 3. Use the newly created specification function! |
| 89 | +# Note the new arguments `num_hidden` and `hidden_units`. |
| 90 | +#' model_spec <- my_mlp( |
| 91 | +#' num_hidden = 2, |
| 92 | +#' hidden_units = 64, |
| 93 | +#' epochs = 10, |
| 94 | +#' learn_rate = 0.01 |
| 95 | +#' ) |
| 96 | +#' |
| 97 | +#' print(model_spec) |
| 98 | +#' } |
| 99 | +#' } |
| 100 | +create_keras_spec <- function( |
| 101 | + model_name, |
| 102 | + layer_blocks, |
| 103 | + mode = c("regression", "classification"), |
| 104 | + ..., |
| 105 | + env = parent.frame() |
| 106 | +) { |
| 107 | + mode <- arg_match(mode) |
| 108 | + args_info <- collect_spec_args(layer_blocks) |
| 109 | + spec_fun <- build_spec_function( |
| 110 | + model_name, |
| 111 | + mode, |
| 112 | + args_info$all_args, |
| 113 | + args_info$parsnip_names, |
| 114 | + layer_blocks |
| 115 | + ) |
| 116 | + |
| 117 | + register_core_model(model_name, mode) |
| 118 | + register_model_args(model_name, args_info$parsnip_names) |
| 119 | + register_fit_predict(model_name, mode, layer_blocks) |
| 120 | + register_update_method(model_name, args_info$parsnip_names) |
| 121 | + |
| 122 | + env_poke(env, model_name, spec_fun) |
| 123 | + invisible(NULL) |
| 124 | +} |
0 commit comments