Skip to content

Commit 22065f4

Browse files
committed
Fixing issues found with compile_ and fit_ args, and other general improvements
1 parent 6b491a5 commit 22065f4

15 files changed

+221
-145
lines changed

R/create_keras_spec_helpers.R

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
#' (assumed to be the `model` object), to find tunable hyperparameters.
1313
#'
1414
#' For **functional models** (`functional = TRUE`):
15-
#' - It does **not** create `num_{block_name}` arguments.
15+
#' - It creates `num_{block_name}` arguments to control block repetition.
1616
#' - It inspects the arguments of each block function. Arguments whose names
1717
#' match other block names are considered graph connections (inputs) and are
1818
#' ignored. The remaining arguments are treated as tunable hyperparameters.
@@ -30,21 +30,11 @@
3030
#' @noRd
3131
collect_spec_args <- function(
3232
layer_blocks,
33-
functional,
34-
global_args = c(
35-
"epochs",
36-
"batch_size",
37-
"learn_rate",
38-
"validation_split",
39-
"verbose",
40-
"compile_loss",
41-
"compile_optimizer",
42-
"compile_metrics"
43-
)
33+
functional
4434
) {
45-
if (any(c("compile", "optimizer") %in% names(layer_blocks))) {
35+
if (any(c("compile", "fit", "optimizer") %in% names(layer_blocks))) {
4636
stop(
47-
"`compile` and `optimizer` are protected names and cannot be used as layer block names.",
37+
"`compile`, `fit` and `optimizer` are protected names and cannot be used as layer block names.",
4838
call. = FALSE
4939
)
5040
}
@@ -90,8 +80,25 @@ collect_spec_args <- function(
9080
}
9181
}
9282

93-
# global training parameters
94-
for (g in global_args) {
83+
# Add global training and compile parameters dynamically
84+
# These are discovered from keras3::fit and keras3::compile in zzz.R
85+
fit_params <- if (length(keras_fit_arg_names) > 0) {
86+
paste0("fit_", keras_fit_arg_names)
87+
} else {
88+
character()
89+
}
90+
compile_params <- if (length(keras_compile_arg_names) > 0) {
91+
paste0("compile_", keras_compile_arg_names)
92+
} else {
93+
character()
94+
}
95+
96+
# learn_rate is a special convenience argument for the default optimizer
97+
special_params <- "learn_rate"
98+
99+
dynamic_global_args <- c(special_params, fit_params, compile_params)
100+
101+
for (g in dynamic_global_args) {
95102
all_args[[g]] <- rlang::zap()
96103
parsnip_names <- c(parsnip_names, g)
97104
}

R/generate_roxygen_docs.R

Lines changed: 43 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -70,20 +70,22 @@ generate_roxygen_docs <- function(
7070

7171
# Group args for structured documentation
7272
num_params <- arg_names[startsWith(arg_names, "num_")]
73+
fit_params <- arg_names[startsWith(arg_names, "fit_")]
7374
compile_params <- arg_names[startsWith(arg_names, "compile_")]
74-
global_params <- c(
75-
"epochs",
76-
"batch_size",
77-
"learn_rate",
78-
"validation_split",
79-
"verbose"
80-
)
75+
# `learn_rate` is a special top-level convenience argument
76+
special_params <- "learn_rate"
77+
8178
block_params <- setdiff(
8279
arg_names,
83-
c(num_params, compile_params, global_params)
80+
c(num_params, fit_params, compile_params, special_params)
8481
)
8582

8683
# Document block-specific params
84+
if ("learn_rate" %in% block_params) {
85+
# This can happen if a user names a block `learn` and it has a `rate` param.
86+
# It's an edge case, but we should not document it twice.
87+
block_params <- setdiff(block_params, "learn_rate")
88+
}
8789
if (length(block_params) > 0) {
8890
param_docs <- c(
8991
param_docs,
@@ -135,40 +137,46 @@ generate_roxygen_docs <- function(
135137
)
136138
}
137139

138-
# Document global params
139-
global_param_desc <- list(
140-
epochs = "The total number of iterations to train the model.",
141-
batch_size = "The number of samples per gradient update.",
142-
learn_rate = "The learning rate for the default Adam optimizer. This is ignored if `compile_optimizer` is provided as a pre-built object.",
143-
validation_split = "The proportion of the training data to be used as a validation set.",
144-
verbose = "The level of verbosity for model fitting (0, 1, or 2)."
145-
)
140+
# Document special `learn_rate` param
146141
param_docs <- c(
147142
param_docs,
148-
purrr::map_chr(global_params, function(p) {
149-
paste0("@param ", p, " ", global_param_desc[[p]])
150-
})
143+
"@param learn_rate The learning rate for the default Adam optimizer. This is ignored if `compile_optimizer` is provided as a pre-built Keras optimizer object."
151144
)
152145

153146
# Document compile params
154-
compile_param_desc <- list(
155-
compile_loss = "The loss function for compiling the model. Can be a string (e.g., 'mse') or a Keras loss object. Overrides the default.",
156-
compile_optimizer = "The optimizer for compiling the model. Can be a string (e.g., 'sgd') or a Keras optimizer object. Overrides the default.",
157-
compile_metrics = "A character vector of metrics to monitor during training (e.g., `c('mae', 'mse')`). Overrides the default."
158-
)
159-
param_docs <- c(
160-
param_docs,
161-
purrr::map_chr(compile_params, function(p) {
162-
paste0("@param ", p, " ", compile_param_desc[[p]])
163-
})
164-
)
147+
if (length(compile_params) > 0) {
148+
param_docs <- c(
149+
param_docs,
150+
purrr::map_chr(compile_params, function(p) {
151+
paste0(
152+
"@param ",
153+
p,
154+
" Argument to `keras3::compile()`. See the 'Model Compilation' section."
155+
)
156+
})
157+
)
158+
}
159+
160+
# Document fit params
161+
if (length(fit_params) > 0) {
162+
param_docs <- c(
163+
param_docs,
164+
purrr::map_chr(fit_params, function(p) {
165+
paste0(
166+
"@param ",
167+
p,
168+
" Argument to `keras3::fit()`. See the 'Model Fitting' section."
169+
)
170+
})
171+
)
172+
}
165173

166174
# Add ... param
167175
param_docs <- c(
168176
param_docs,
169177
paste0(
170-
"@param ... Additional arguments passed to the Keras engine. This is commonly used for arguments to `keras3::fit()` (prefixed with `fit_`). ",
171-
"See the 'Model Fitting' and 'Model Compilation' sections for details."
178+
"@param ... Additional arguments passed to the Keras engine. Use this for arguments to `keras3::fit()` or `keras3::compile()` ",
179+
"that are not exposed as top-level arguments."
172180
)
173181
)
174182

@@ -178,7 +186,8 @@ generate_roxygen_docs <- function(
178186
"#' @section Model Architecture (Functional API):",
179187
"#' The Keras model is constructed using the Functional API. Each layer block function's arguments",
180188
"#' determine its inputs. For example, a block `function(input_a, input_b, ...)` will be connected",
181-
"#' to the outputs of the `input_a` and `input_b` blocks.",
189+
"#' to the outputs of the `input_a` and `input_b` blocks. You can also repeat a block by setting",
190+
"#' the `num_{block_name}` argument, provided the block has a single input tensor.",
182191
"#' The first block in `layer_blocks` is assumed to be the input layer and should not have inputs from other layers."
183192
)
184193
see_also_fit <- "generic_functional_fit()"
@@ -213,7 +222,7 @@ generate_roxygen_docs <- function(
213222
"#' @section Model Fitting:",
214223
"#' The model is fit using `keras3::fit()`. You can pass any argument to this function by prefixing it with `fit_`.",
215224
"#' For example, to add Keras callbacks, you can pass `fit_callbacks = list(callback_early_stopping())`.",
216-
"#' The `epochs` and `batch_size` arguments are also passed to `fit()`."
225+
"#' Common arguments include `fit_epochs`, `fit_batch_size`, and `fit_validation_split`."
217226
)
218227

219228
# Other tags

R/generic_fit_helpers.R

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,15 @@
55
#' list of arguments, resolves them, and combines them with defaults.
66
#'
77
#' @details
8-
#' It handles the special logic for the `optimizer`, where a string name is
9-
#' resolved to a Keras optimizer object, applying the `learn_rate` if necessary.
10-
#' It also resolves string names for `loss` and `metrics` using `get_keras_object()`.
8+
#' This function orchestrates the compilation setup. It gives precedence to
9+
#' user-provided arguments (e.g., `compile_optimizer`) over the mode-based
10+
#' defaults. It handles the special logic for the `optimizer`, where a string
11+
#' name (e.g., `"sgd"`) is resolved to a Keras optimizer object, applying the
12+
#' top-level `learn_rate` if necessary. It also resolves string names for `loss`
13+
#' and `metrics` using `get_keras_object()`.
1114
#'
1215
#' @param all_args The list of all arguments passed to the fitting function's `...`.
13-
#' @param learn_rate The main `learn_rate` parameter.
16+
#' @param learn_rate The top-level `learn_rate` parameter.
1417
#' @param default_loss The default loss function to use if not provided.
1518
#' @param default_metrics The default metric(s) to use if not provided.
1619
#' @return A named list of arguments ready to be passed to `keras3::compile()`.
@@ -69,6 +72,14 @@ collect_compile_args <- function(
6972
!names(user_compile_args) %in% c("optimizer", "loss", "metrics")
7073
]
7174
final_compile_args <- c(final_compile_args, other_args)
75+
# Filter out arguments that are NULL or rlang_zap before passing to keras3::compile
76+
final_compile_args <- final_compile_args[
77+
!vapply(
78+
final_compile_args,
79+
function(x) inherits(x, "rlang_zap"),
80+
logical(1)
81+
)
82+
]
7283
final_compile_args
7384
}
7485

@@ -78,21 +89,22 @@ collect_compile_args <- function(
7889
#' This internal helper extracts all arguments prefixed with `fit_` from a list
7990
#' of arguments and combines them with the core arguments for `keras3::fit()`.
8091
#'
92+
#' @details
93+
#' It constructs the final list of arguments for `keras3::fit()`. It starts with
94+
#' the required data (`x`, `y`) and the `verbose` setting. It then merges any
95+
#' user-provided arguments from the model specification (e.g., `fit_epochs`,
96+
#' `fit_callbacks`), with the user-provided arguments taking precedence over
97+
#' any defaults.
98+
#'
8199
#' @param x_proc The processed predictor data.
82100
#' @param y_mat The processed outcome data.
83-
#' @param epochs The number of epochs.
84-
#' @param batch_size The batch size.
85-
#' @param validation_split The validation split proportion.
86101
#' @param verbose The verbosity level.
87102
#' @param all_args The list of all arguments passed to the fitting function's `...`.
88103
#' @return A named list of arguments ready to be passed to `keras3::fit()`.
89104
#' @noRd
90105
collect_fit_args <- function(
91106
x_proc,
92107
y_mat,
93-
epochs,
94-
batch_size,
95-
validation_split,
96108
verbose,
97109
all_args
98110
) {
@@ -101,15 +113,24 @@ collect_fit_args <- function(
101113
user_fit_args <- all_args[fit_arg_names]
102114
names(user_fit_args) <- sub("^fit_", "", names(user_fit_args))
103115

104-
final_fit_args <- c(
105-
list(
106-
x = x_proc,
107-
y = y_mat,
108-
epochs = epochs,
109-
batch_size = batch_size,
110-
validation_split = validation_split,
111-
verbose = verbose
112-
),
113-
user_fit_args
116+
# Build the core argument set. `verbose` can be overridden by `fit_verbose`.
117+
base_args <- list(
118+
x = x_proc,
119+
y = y_mat,
120+
verbose = verbose
114121
)
122+
123+
merged_args <- utils::modifyList(base_args, user_fit_args)
124+
125+
# Filter out arguments that are NULL or rlang_zap before passing to keras3::fit
126+
merged_args <- merged_args[
127+
!vapply(
128+
merged_args,
129+
function(x) {
130+
inherits(x, "rlang_zap")
131+
},
132+
logical(1)
133+
)
134+
]
135+
merged_args
115136
}

R/generic_sequential_fit.R

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -71,28 +71,13 @@ generic_sequential_fit <- function(
7171
x,
7272
y,
7373
layer_blocks,
74-
epochs = 10,
75-
batch_size = 32,
76-
learn_rate = 0.01,
77-
validation_split = 0.2,
78-
verbose = 0,
7974
...
8075
) {
81-
# --- 0. Resolve arguments ---
82-
# Parsnip passes "zapped" arguments for user-unspecified args.
83-
# This helper replaces them with the function's defaults.
84-
resolve_default <- function(x, default) {
85-
if (inherits(x, "rlang_zap")) default else x
86-
}
87-
fmls <- rlang::fn_fmls(sys.function())
88-
epochs <- resolve_default(epochs, fmls$epochs)
89-
batch_size <- resolve_default(batch_size, fmls$batch_size)
90-
learn_rate <- resolve_default(learn_rate, fmls$learn_rate)
91-
validation_split <- resolve_default(validation_split, fmls$validation_split)
92-
verbose <- resolve_default(verbose, fmls$verbose)
93-
94-
# --- 1. Data & Input Shape Preparation ---
76+
# --- 0. Argument & Data Preparation ---
9577
all_args <- list(...)
78+
learn_rate <- all_args$learn_rate %||% 0.01
79+
verbose <- all_args$verbose %||% 0
80+
9681
# Handle both standard tabular data (matrix) and list-columns of arrays
9782
# (for images/sequences) that come from recipes.
9883
if (is.data.frame(x) && ncol(x) == 1 && is.list(x[[1]])) {
@@ -188,9 +173,6 @@ generic_sequential_fit <- function(
188173
fit_args <- collect_fit_args(
189174
x_proc,
190175
y_mat,
191-
epochs,
192-
batch_size,
193-
validation_split,
194176
verbose,
195177
all_args
196178
)

R/globals.R

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
11
utils::globalVariables(
2-
c("object", "new_data", "engine", "fresh", "parameters")
2+
c(
3+
"object",
4+
"new_data",
5+
"engine",
6+
"fresh",
7+
"parameters",
8+
"keras_fit_arg_names",
9+
"keras_compile_arg_names"
10+
)
311
)

R/register_model_args.R

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
#' \item Other arguments are mapped based on their suffix (e.g., `dense_units`
1818
#' is mapped based on `units`). The internal `keras_dials_map` object
1919
#' contains common mappings like `units` -> `dials::hidden_units()`.
20-
#' \item Arguments for `compile_loss` and `compile_optimizer` are mapped to
21-
#' custom `dials` parameter functions within `kerasnip`.
20+
#' \item Arguments for `compile_loss` and `compile_optimizer` are mapped to custom
21+
#' `dials` parameter functions (`loss_function_keras()` and `optimizer_function()`)
22+
#' that are part of the `kerasnip` package itself. The function correctly
23+
#' sets the `pkg` for these to `kerasnip`.
2224
#' }
2325
#'
2426
#' @param model_name The name of the new model specification.
@@ -43,9 +45,9 @@ register_model_args <- function(model_name, parsnip_names) {
4345
"dropout",
4446
"learn_rate",
4547
"learn_rate",
48+
"fit_epochs",
4649
"epochs",
47-
"epochs",
48-
"batch_size",
50+
"fit_batch_size",
4951
"batch_size",
5052
"compile_loss", # parsnip arg
5153
"loss_function_keras", # dials function from kerasnip
@@ -54,7 +56,7 @@ register_model_args <- function(model_name, parsnip_names) {
5456
)
5557

5658
# We now allow optimizer to be tuned. Metrics are for tracking, not training.
57-
non_tunable <- c("verbose")
59+
non_tunable <- c("fit_verbose")
5860

5961
for (arg in parsnip_names) {
6062
if (arg %in% non_tunable) {
File renamed without changes.

0 commit comments

Comments
 (0)