|
| 1 | +#' Create a Custom Keras Functional API Model Specification for Tidymodels |
| 2 | +#' |
| 3 | +#' This function acts as a factory to generate a new `parsnip` model |
| 4 | +#' specification based on user-defined blocks of Keras layers using the |
| 5 | +#' Functional API. This allows for creating complex, tunable architectures |
| 6 | +#' with non-linear topologies that integrate seamlessly with the `tidymodels` |
| 7 | +#' ecosystem. |
| 8 | +#' |
| 9 | +#' @param model_name A character string for the name of the new model |
| 10 | +#' specification function (e.g., "custom_resnet"). This should be a valid R |
| 11 | +#' function name. |
| 12 | +#' @param layer_blocks A named list of functions where each function defines a |
| 13 | +#' "block" (a node) in the model graph. The list names are crucial as they |
| 14 | +#' define the names of the nodes. The arguments of each function define how |
| 15 | +#' the nodes are connected. See the "Model Graph Connectivity" section for |
| 16 | +#' details. |
| 17 | +#' @param mode A character string, either "regression" or "classification". |
| 18 | +#' @param ... Reserved for future use. Currently not used. |
| 19 | +#' @param env The environment in which to create the new model specification |
| 20 | +#' function and its associated `update()` method. Defaults to the calling |
| 21 | +#' environment (`parent.frame()`). |
| 22 | +#' |
| 23 | +#' @details |
| 24 | +#' This function generates all the boilerplate needed to create a custom, |
| 25 | +#' tunable `parsnip` model specification that uses the Keras Functional API. |
| 26 | +#' This is ideal for models with complex, non-linear topologies, such as |
| 27 | +#' networks with multiple inputs/outputs or residual connections. |
| 28 | +#' |
| 29 | +#' The function inspects the arguments of your `layer_blocks` functions and |
| 30 | +#' makes them available as tunable parameters in the generated model |
| 31 | +#' specification, prefixed with the block's name (e.g., `dense_units`). |
| 32 | +#' Common training parameters such as `epochs` and `learn_rate` are also added. |
| 33 | +#' |
| 34 | +#' @section Model Graph Connectivity: |
| 35 | +#' `kerasnip` builds the model's directed acyclic graph by inspecting the |
| 36 | +#' arguments of each function in the `layer_blocks` list. The connection logic |
| 37 | +#' is as follows: |
| 38 | +#' |
| 39 | +#' 1. The **names of the elements** in the `layer_blocks` list define the names |
| 40 | +#' of the nodes in your graph (e.g., `main_input`, `dense_path`, `output`). |
| 41 | +#' 2. The **names of the arguments** in each block function specify its inputs. |
| 42 | +#' A block function like `my_block <- function(input_a, input_b, ...)` |
| 43 | +#' declares that it needs input from the nodes named `input_a` and `input_b`. |
| 44 | +#' `kerasnip` will automatically supply the output tensors from those nodes |
| 45 | +#' when calling `my_block`. |
| 46 | +#' |
| 47 | +#' There are two special requirements: |
| 48 | +#' * **Input Block**: The first block in the list is treated as the input |
| 49 | +#' node. Its function should not take other blocks as input, but it can have |
| 50 | +#' an `input_shape` argument, which is supplied automatically during fitting. |
| 51 | +#' * **Output Block**: Exactly one block must be named `"output"`. The tensor |
| 52 | +#' returned by this block is used as the final output of the Keras model. |
| 53 | +#' |
| 54 | +#' A key feature is the automatic creation of `num_{block_name}` arguments |
| 55 | +#' (e.g., `num_dense_path`). This allows you to control how many times a block |
| 56 | +#' is repeated, making it easy to tune the depth of your network. A block can |
| 57 | +#' only be repeated if it has exactly one input from another block in the graph. |
| 58 | +#' |
| 59 | +#' The new model specification function and its `update()` method are created |
| 60 | +#' in the environment specified by the `env` argument. |
| 61 | +#' |
| 62 | +#' @importFrom rlang enquos dots_list arg_match env_poke |
| 63 | +#' @importFrom parsnip update_dot_check |
| 64 | +#' |
| 65 | +#' @return Invisibly returns `NULL`. Its primary side effect is to create a |
| 66 | +#' new model specification function (e.g., `custom_resnet()`) in the |
| 67 | +#' specified environment and register the model with `parsnip` so it can be |
| 68 | +#' used within the `tidymodels` framework. |
| 69 | +#' |
| 70 | +#' @seealso [remove_keras_spec()], [parsnip::new_model_spec()], |
| 71 | +#' [create_keras_sequential_spec()] |
| 72 | +#' |
| 73 | +#' @export |
| 74 | +#' @examples |
| 75 | +#' \dontrun{ |
| 76 | +#' if (requireNamespace("keras3", quietly = TRUE)) { |
| 77 | +#' library(keras3) |
| 78 | +#' library(parsnip) |
| 79 | +#' |
| 80 | +#' # 1. Define block functions. These are the building blocks of our model. |
| 81 | +#' # An input block that receives the data's shape automatically. |
| 82 | +#' input_block <- function(input_shape) layer_input(shape = input_shape) |
| 83 | +#' |
| 84 | +#' # A dense block with a tunable `units` parameter. |
| 85 | +#' dense_block <- function(tensor, units) { |
| 86 | +#' tensor |> layer_dense(units = units, activation = "relu") |
| 87 | +#' } |
| 88 | +#' |
| 89 | +#' # A block that adds two tensors together (for the residual connection). |
| 90 | +#' add_block <- function(input_a, input_b) layer_add(list(input_a, input_b)) |
| 91 | +#' |
| 92 | +#' # An output block for regression. |
| 93 | +#' output_block_reg <- function(tensor) layer_dense(tensor, units = 1) |
| 94 | +#' |
| 95 | +#' # 2. Create the spec. The `layer_blocks` list defines the graph. |
| 96 | +#' create_keras_functional_spec( |
| 97 | +#' model_name = "my_resnet_spec", |
| 98 | +#' layer_blocks = list( |
| 99 | +#' # The names of list elements are the node names. |
| 100 | +#' main_input = input_block, |
| 101 | +#' |
| 102 | +#' # The argument `main_input` connects this block to the input node. |
| 103 | +#' dense_path = function(main_input, units = 32) dense_block(main_input, units), |
| 104 | +#' |
| 105 | +#' # This block's arguments connect it to the original input AND the dense layer. |
| 106 | +#' add_residual = function(main_input, dense_path) add_block(main_input, dense_path), |
| 107 | +#' |
| 108 | +#' # This block must be named 'output'. It connects to the residual add layer. |
| 109 | +#' output = function(add_residual) output_block_reg(add_residual) |
| 110 | +#' ), |
| 111 | +#' mode = "regression" |
| 112 | +#' ) |
| 113 | +#' |
| 114 | +#' # 3. Use the newly created specification function! |
| 115 | +#' # The `dense_path_units` argument was created automatically. |
| 116 | +#' model_spec <- my_resnet_spec(dense_path_units = 64, epochs = 10) |
| 117 | +#' |
| 118 | +#' # You could also tune the number of dense layers since it has a single input: |
| 119 | +#' # model_spec <- my_resnet_spec(num_dense_path = 2, dense_path_units = 32) |
| 120 | +#' |
| 121 | +#' print(model_spec) |
| 122 | +#' # tune::tunable(model_spec) |
| 123 | +#' } |
| 124 | +#' } |
create_keras_functional_spec <- function(
  model_name,
  layer_blocks,
  mode = c("regression", "classification"),
  ...,
  env = parent.frame()
) {
  # 1. Argument validation ----
  # Fail fast with a targeted message instead of letting malformed input
  # surface as an obscure error further down the spec-building code.
  if (
    !is.character(model_name) ||
      length(model_name) != 1L ||
      is.na(model_name) ||
      !nzchar(model_name)
  ) {
    stop(
      "`model_name` must be a single, non-empty character string.",
      call. = FALSE
    )
  }

  if (!is.list(layer_blocks) || length(layer_blocks) == 0L) {
    stop(
      "`layer_blocks` must be a non-empty named list of functions.",
      call. = FALSE
    )
  }

  if (!all(vapply(layer_blocks, is.function, logical(1)))) {
    stop("Every element of `layer_blocks` must be a function.", call. = FALSE)
  }

  # Block names are the node names of the model graph, so each block needs a
  # non-empty name and names must not collide.
  block_names <- names(layer_blocks)
  if (
    is.null(block_names) ||
      anyNA(block_names) ||
      !all(nzchar(block_names)) ||
      anyDuplicated(block_names) > 0L
  ) {
    stop(
      "`layer_blocks` must be a named list with unique, non-empty names.",
      call. = FALSE
    )
  }

  # The documented contract requires a block named "output"; uniqueness was
  # checked above, so presence implies exactly one.
  if (!"output" %in% block_names) {
    stop(
      "`layer_blocks` must contain a block named \"output\".",
      call. = FALSE
    )
  }

  # Resolve `mode` to exactly one of the choices in the signature default.
  mode <- rlang::arg_match(mode)

  # 2. Delegate to the shared implementation ----
  # `...` is reserved for future use and intentionally not forwarded.
  # `functional = TRUE` selects the Functional API code path; the value
  # returned here is whatever `create_keras_spec_impl()` returns.
  # NOTE(review): `env` is passed positionally, as in the original call; this
  # assumes it is the remaining unmatched formal of the helper -- confirm.
  create_keras_spec_impl(
    model_name,
    layer_blocks,
    mode,
    functional = TRUE,
    env
  )
}
0 commit comments